Просмотр исходного кода

Make "aiAssistantId" and "initialUserMessage" optional for thread creation

Shun Miyazawa 1 год назад
Родитель
Сommit
cca907379f

+ 4 - 2
apps/app/src/features/openai/client/components/AiAssistant/AiAssistantSidebar/AiAssistantSidebar.tsx

@@ -144,7 +144,7 @@ const AiAssistantSidebarSubstance: React.FC<AiAssistantSidebarSubstanceProps> =
       try {
         const res = await apiv3Post<IThreadRelationHasId>('/openai/thread', {
           aiAssistantId: aiAssistantData?._id,
-          initialUserMessage: newUserMessage.content,
+          initialUserMessage: isEditorAssistant ? undefined : newUserMessage.content,
         });
 
         const thread = res.data;
@@ -155,7 +155,9 @@ const AiAssistantSidebarSubstance: React.FC<AiAssistantSidebarSubstanceProps> =
         currentThreadId_ = thread.threadId;
 
         // No need to await because data is not used
-        mutateThreadData();
+        if (!isEditorAssistant) {
+          mutateThreadData();
+        }
       }
       catch (err) {
         logger.error(err.toString());

+ 0 - 1
apps/app/src/features/openai/server/models/thread-relation.ts

@@ -28,7 +28,6 @@ const schema = new Schema<ThreadRelationDocument, ThreadRelationModel>({
   aiAssistant: {
     type: Schema.Types.ObjectId,
     ref: 'AiAssistant',
-    required: true,
   },
   threadId: {
     type: String,

+ 4 - 11
apps/app/src/features/openai/server/routes/thread.ts

@@ -17,8 +17,8 @@ import { certifyAiService } from './middlewares/certify-ai-service';
 const logger = loggerFactory('growi:routes:apiv3:openai:thread');
 
 type ReqBody = {
-  aiAssistantId: string,
-  initialUserMessage: string,
+  aiAssistantId?: string,
+  initialUserMessage?: string,
 }
 
 type CreateThreadReq = Request<undefined, ApiV3Response, ReqBody> & { user: IUserHasId };
@@ -29,8 +29,8 @@ export const createThreadHandlersFactory: CreateThreadFactory = (crowi) => {
   const loginRequiredStrictly = require('~/server/middlewares/login-required')(crowi);
 
   const validator: ValidationChain[] = [
-    body('aiAssistantId').isMongoId().withMessage('aiAssistantId must be string'),
-    body('initialUserMessage').isString().withMessage('initialUserMessage must be string'),
+    body('aiAssistantId').optional().isMongoId().withMessage('aiAssistantId must be string'),
+    body('initialUserMessage').optional().isString().withMessage('initialUserMessage must be string'),
   ];
 
   return [
@@ -44,14 +44,7 @@ export const createThreadHandlersFactory: CreateThreadFactory = (crowi) => {
 
       try {
         const { aiAssistantId, initialUserMessage } = req.body;
-
-        const isAiAssistantUsable = await openaiService.isAiAssistantUsable(aiAssistantId, req.user);
-        if (!isAiAssistantUsable) {
-          return res.apiv3Err(new ErrorV3('The specified AI assistant is not usable'), 400);
-        }
-
         const thread = await openaiService.createThread(req.user._id, aiAssistantId, initialUserMessage);
-
         return res.apiv3(thread);
       }
       catch (err) {

+ 9 - 7
apps/app/src/features/openai/server/services/client-delegator/azure-openai-client-delegator.ts

@@ -23,14 +23,16 @@ export class AzureOpenaiClientDelegator implements IOpenaiClientDelegator {
     // TODO: initialize openaiVectorStoreId property
   }
 
-  async createThread(vectorStoreId: string): Promise<OpenAI.Beta.Threads.Thread> {
-    return this.client.beta.threads.create({
-      tool_resources: {
-        file_search: {
-          vector_store_ids: [vectorStoreId],
+  async createThread(vectorStoreId?: string): Promise<OpenAI.Beta.Threads.Thread> {
+    return this.client.beta.threads.create(vectorStoreId != null
+      ? {
+        tool_resources: {
+          file_search: {
+            vector_store_ids: [vectorStoreId],
+          },
         },
-      },
-    });
+      }
+      : undefined);
   }
 
   async updateThread(threadId: string, vectorStoreId: string): Promise<OpenAI.Beta.Threads.Thread> {

+ 1 - 1
apps/app/src/features/openai/server/services/client-delegator/interfaces.ts

@@ -4,7 +4,7 @@ import type { Uploadable } from 'openai/uploads';
 import type { MessageListParams } from '../../../interfaces/message';
 
 export interface IOpenaiClientDelegator {
-  createThread(vectorStoreId: string): Promise<OpenAI.Beta.Threads.Thread>
+  createThread(vectorStoreId?: string): Promise<OpenAI.Beta.Threads.Thread>
   updateThread(threadId: string, vectorStoreId: string): Promise<OpenAI.Beta.Threads.Thread>
   retrieveThread(threadId: string): Promise<OpenAI.Beta.Threads.Thread>
   deleteThread(threadId: string): Promise<OpenAI.Beta.Threads.ThreadDeleted>

+ 9 - 7
apps/app/src/features/openai/server/services/client-delegator/openai-client-delegator.ts

@@ -24,14 +24,16 @@ export class OpenaiClientDelegator implements IOpenaiClientDelegator {
     this.client = new OpenAI({ apiKey });
   }
 
-  async createThread(vectorStoreId: string): Promise<OpenAI.Beta.Threads.Thread> {
-    return this.client.beta.threads.create({
-      tool_resources: {
-        file_search: {
-          vector_store_ids: [vectorStoreId],
+  async createThread(vectorStoreId?: string): Promise<OpenAI.Beta.Threads.Thread> {
+    return this.client.beta.threads.create(vectorStoreId != null
+      ? {
+        tool_resources: {
+          file_search: {
+            vector_store_ids: [vectorStoreId],
+          },
         },
-      },
-    });
+      }
+      : undefined);
   }
 
   async retrieveThread(threadId: string): Promise<OpenAI.Beta.Threads.Thread> {

+ 4 - 5
apps/app/src/features/openai/server/services/openai.ts

@@ -65,7 +65,7 @@ const convertPathPatternsToRegExp = (pagePathPatterns: string[]): Array<string |
 };
 
 export interface IOpenaiService {
-  createThread(userId: string, aiAssistantId: string, initialUserMessage: string): Promise<ThreadRelationDocument>;
+  createThread(userId: string, aiAssistantId?: string, initialUserMessage?: string): Promise<ThreadRelationDocument>;
   getThreadsByAiAssistantId(aiAssistantId: string): Promise<ThreadRelationDocument[]>
   deleteThread(threadRelationId: string): Promise<ThreadRelationDocument>;
   deleteExpiredThreads(limit: number, apiCallInterval: number): Promise<void>; // for CronJob
@@ -117,9 +117,7 @@ class OpenaiService implements IOpenaiService {
     return threadTitle;
   }
 
-  async createThread(userId: string, aiAssistantId: string, initialUserMessage: string): Promise<ThreadRelationDocument> {
-    const vectorStoreRelation = await this.getVectorStoreRelationByAiAssistantId(aiAssistantId);
-
+  async createThread(userId: string, aiAssistantId?: string, initialUserMessage?: string): Promise<ThreadRelationDocument> {
     let threadTitle: string | null = null;
     if (initialUserMessage != null) {
       try {
@@ -131,7 +129,8 @@ class OpenaiService implements IOpenaiService {
     }
 
     try {
-      const thread = await this.client.createThread(vectorStoreRelation.vectorStoreId);
+      const vectorStoreRelation = aiAssistantId != null ? await this.getVectorStoreRelationByAiAssistantId(aiAssistantId) : null;
+      const thread = await this.client.createThread(vectorStoreRelation?.vectorStoreId);
       const threadRelation = await ThreadRelationModel.create({
         userId,
         aiAssistant: aiAssistantId,