Browse Source

Add validation to check AiAssistant availability

Shun Miyazawa 1 year ago
parent
commit
03e686aab8

+ 3 - 1
apps/app/src/features/openai/client/components/AiAssistant/AiAssistantChatSidebar/AiAssistantChatSidebar.tsx

@@ -135,7 +135,9 @@ const AiAssistantChatSidebarSubstance: React.FC<AiAssistantChatSidebarSubstanceP
       const response = await fetch('/_api/v3/openai/message', {
       const response = await fetch('/_api/v3/openai/message', {
         method: 'POST',
         method: 'POST',
         headers: { 'Content-Type': 'application/json' },
         headers: { 'Content-Type': 'application/json' },
-        body: JSON.stringify({ userMessage: data.input, threadId: currentThreadId, summaryMode: data.summaryMode }),
+        body: JSON.stringify({
+          userMessage: data.input, threadId: currentThreadId, summaryMode: data.summaryMode, aiAssistantId,
+        }),
       });
       });
 
 
       if (!response.ok) {
       if (!response.ok) {

+ 14 - 2
apps/app/src/features/openai/server/routes/message.ts

@@ -16,6 +16,7 @@ import loggerFactory from '~/utils/logger';
 import { MessageErrorCode, type StreamErrorCode } from '../../interfaces/message-error';
 import { MessageErrorCode, type StreamErrorCode } from '../../interfaces/message-error';
 import { openaiClient } from '../services/client';
 import { openaiClient } from '../services/client';
 import { getStreamErrorCode } from '../services/getStreamErrorCode';
 import { getStreamErrorCode } from '../services/getStreamErrorCode';
+import { getOpenaiService } from '../services/openai';
 import { replaceAnnotationWithPageLink } from '../services/replace-annotation-with-page-link';
 import { replaceAnnotationWithPageLink } from '../services/replace-annotation-with-page-link';
 
 
 import { certifyAiService } from './middlewares/certify-ai-service';
 import { certifyAiService } from './middlewares/certify-ai-service';
@@ -25,6 +26,7 @@ const logger = loggerFactory('growi:routes:apiv3:openai:message');
 
 
 type ReqBody = {
 type ReqBody = {
   userMessage: string,
   userMessage: string,
+  aiAssistantId: string,
   threadId?: string,
   threadId?: string,
   summaryMode?: boolean,
   summaryMode?: boolean,
 }
 }
@@ -44,19 +46,29 @@ export const postMessageHandlersFactory: PostMessageHandlersFactory = (crowi) =>
       .withMessage('userMessage must be string')
       .withMessage('userMessage must be string')
       .notEmpty()
       .notEmpty()
       .withMessage('userMessage must be set'),
       .withMessage('userMessage must be set'),
+    body('aiAssistantId').isMongoId().withMessage('aiAssistantId must be string'),
     body('threadId').optional().isString().withMessage('threadId must be string'),
     body('threadId').optional().isString().withMessage('threadId must be string'),
   ];
   ];
 
 
   return [
   return [
     accessTokenParser, loginRequiredStrictly, certifyAiService, validator, apiV3FormValidator,
     accessTokenParser, loginRequiredStrictly, certifyAiService, validator, apiV3FormValidator,
     async(req: Req, res: ApiV3Response) => {
     async(req: Req, res: ApiV3Response) => {
-
-      const threadId = req.body.threadId;
+      const { aiAssistantId, threadId } = req.body;
 
 
       if (threadId == null) {
       if (threadId == null) {
         return res.apiv3Err(new ErrorV3('threadId is not set', MessageErrorCode.THREAD_ID_IS_NOT_SET), 400);
         return res.apiv3Err(new ErrorV3('threadId is not set', MessageErrorCode.THREAD_ID_IS_NOT_SET), 400);
       }
       }
 
 
+      const openaiService = getOpenaiService();
+      if (openaiService == null) {
+        return res.apiv3Err(new ErrorV3('GROWI AI is not enabled'), 501);
+      }
+
+      const isAiAssistantUsable = await openaiService.isAiAssistantUsable(aiAssistantId, req.user);
+      if (!isAiAssistantUsable) {
+        return res.apiv3Err(new ErrorV3('The specified AI assistant is not usable'), 400);
+      }
+
       let stream: AssistantStream;
       let stream: AssistantStream;
 
 
       try {
       try {