Răsfoiți Sursa

Merge pull request #9764 from weseek/feat/163449-chat-using-selected-ai-assistant-id

feat: Chat using selected aiAssistantId
Yuki Takei 1 an în urmă
părinte
comite
455a521985

+ 6 - 7
apps/app/src/features/openai/client/components/AiAssistant/AiAssistantSidebar/AiAssistantDropdown.tsx

@@ -1,5 +1,5 @@
 
-import React, { useState, useMemo, useCallback } from 'react';
+import React, { useMemo, useCallback } from 'react';
 
 import { useTranslation } from 'react-i18next';
 import {
@@ -14,12 +14,11 @@ import { useSWRxAiAssistants } from '../../../stores/ai-assistant';
 import { getShareScopeIcon } from '../../../utils/get-share-scope-Icon';
 
 type Props = {
-  //
+  selectedAiAssistant?: AiAssistantHasId;
+  onSelect(aiAssistant?: AiAssistantHasId): void
 }
 
-export const AiAssistantDropdown = (props: Props): JSX.Element => {
-  const [selectedAiAssistant, setSelectedAiAssistant] = useState<AiAssistantHasId>();
-
+export const AiAssistantDropdown = ({ selectedAiAssistant, onSelect }: Props): JSX.Element => {
   const { t } = useTranslation();
   const { data: aiAssistantData } = useSWRxAiAssistants();
 
@@ -42,8 +41,8 @@ export const AiAssistantDropdown = (props: Props): JSX.Element => {
   }, []);
 
   const selectAiAssistantHandler = useCallback((aiAssistant?: AiAssistantHasId) => {
-    setSelectedAiAssistant(aiAssistant);
-  }, []);
+    onSelect(aiAssistant);
+  }, [onSelect]);
 
   return (
     <UncontrolledDropdown>

+ 11 - 3
apps/app/src/features/openai/client/components/AiAssistant/AiAssistantSidebar/AiAssistantSidebar.tsx

@@ -66,6 +66,7 @@ const AiAssistantSidebarSubstance: React.FC<AiAssistantSidebarSubstanceProps> =
   const [generatingAnswerMessage, setGeneratingAnswerMessage] = useState<Message>();
   const [errorMessage, setErrorMessage] = useState<string | undefined>();
   const [isErrorDetailCollapsed, setIsErrorDetailCollapsed] = useState<boolean>(false);
+  const [selectedAiAssistant, setSelectedAiAssistant] = useState<AiAssistantHasId>();
 
   const { t } = useTranslation();
   const { data: growiCloudUri } = useGrowiCloudUri();
@@ -159,7 +160,7 @@ const AiAssistantSidebarSubstance: React.FC<AiAssistantSidebarSubstanceProps> =
     if (currentThreadId_ == null) {
       try {
         const res = await apiv3Post<IThreadRelationHasId>('/openai/thread', {
-          aiAssistantId: aiAssistantData?._id,
+          aiAssistantId: isEditorAssistant ? selectedAiAssistant?._id : aiAssistantData?._id,
           initialUserMessage: isEditorAssistant ? undefined : newUserMessage.content,
         });
 
@@ -292,7 +293,7 @@ const AiAssistantSidebarSubstance: React.FC<AiAssistantSidebarSubstanceProps> =
     }
 
   // eslint-disable-next-line max-len
-  }, [isGenerating, messageLogs, form, currentThreadId, aiAssistantData?._id, isEditorAssistant, mutateThreadData, t, postMessageForEditorAssistant, postMessageForKnowledgeAssistant, processMessageForKnowledgeAssistant, processMessageForEditorAssistant, growiCloudUri]);
+  }, [isGenerating, messageLogs, form, currentThreadId, aiAssistantData?._id, isEditorAssistant, mutateThreadData, t, postMessageForEditorAssistant, selectedAiAssistant?._id, postMessageForKnowledgeAssistant, processMessageForKnowledgeAssistant, processMessageForEditorAssistant, growiCloudUri]);
 
   const keyDownHandler = (event: KeyboardEvent<HTMLTextAreaElement>) => {
     if (event.key === 'Enter' && (event.ctrlKey || event.metaKey)) {
@@ -312,6 +313,10 @@ const AiAssistantSidebarSubstance: React.FC<AiAssistantSidebarSubstanceProps> =
     // todo: implement
   }, []);
 
+  const selectAiAssistantHandler = useCallback((aiAssistant?: AiAssistantHasId) => {
+    setSelectedAiAssistant(aiAssistant);
+  }, []);
+
   return (
     <>
       <div className="d-flex flex-column vh-100">
@@ -361,7 +366,10 @@ const AiAssistantSidebarSubstance: React.FC<AiAssistantSidebarSubstanceProps> =
                 ? (
                   <>
                     <div className="py-2">
-                      <AiAssistantDropdown />
+                      <AiAssistantDropdown
+                        selectedAiAssistant={selectedAiAssistant}
+                        onSelect={selectAiAssistantHandler}
+                      />
                     </div>
                     <QuickMenuList
                       onClick={clickQuickMenuHandler}

+ 2 - 3
apps/app/src/features/openai/client/services/editor-assistant.ts

@@ -11,7 +11,7 @@ import {
 import { handleIfSuccessfullyParsed } from '~/features/openai/utils/handle-if-successfully-parsed';
 
 interface PostMessage {
-  (threadId: string, userMessage: string, markdown: string, aiAssistantId?: string): Promise<Response>;
+  (threadId: string, userMessage: string, markdown: string): Promise<Response>;
 }
 interface ProcessMessage {
   (data: unknown, handler: {
@@ -22,12 +22,11 @@ interface ProcessMessage {
 }
 
 export const useEditorAssistant = (): { postMessage: PostMessage, processMessage: ProcessMessage } => {
-  const postMessage: PostMessage = useCallback(async(threadId, userMessage, markdown, aiAssistantId) => {
+  const postMessage: PostMessage = useCallback(async(threadId, userMessage, markdown) => {
     const response = await fetch('/_api/v3/openai/edit', {
       method: 'POST',
       headers: { 'Content-Type': 'application/json' },
       body: JSON.stringify({
-        aiAssistantId,
         threadId,
         userMessage,
         markdown,

+ 19 - 3
apps/app/src/features/openai/server/routes/edit/index.ts

@@ -1,3 +1,4 @@
+import { getIdStringForRef } from '@growi/core';
 import type { IUserHasId } from '@growi/core/dist/interfaces';
 import { ErrorV3 } from '@growi/core/dist/models';
 import type { Request, RequestHandler, Response } from 'express';
@@ -17,6 +18,7 @@ import loggerFactory from '~/utils/logger';
 import { LlmEditorAssistantDiffSchema, LlmEditorAssistantMessageSchema } from '../../../interfaces/editor-assistant/llm-response-schemas';
 import type { SseDetectedDiff, SseFinalized, SseMessage } from '../../../interfaces/editor-assistant/sse-schemas';
 import { MessageErrorCode } from '../../../interfaces/message-error';
+import ThreadRelationModel from '../../models/thread-relation';
 import { getOrCreateEditorAssistant } from '../../services/assistant';
 import { openaiClient } from '../../services/client';
 import { LlmResponseStreamProcessor } from '../../services/editor-assistant';
@@ -41,7 +43,6 @@ const LlmEditorAssistantResponseSchema = z.object({
 type ReqBody = {
   userMessage: string,
   markdown: string,
-  aiAssistantId?: string,
   threadId?: string,
 }
 
@@ -74,14 +75,15 @@ export const postMessageToEditHandlersFactory: PostMessageHandlersFactory = (cro
       .withMessage('markdown must be string')
       .notEmpty()
       .withMessage('markdown must be set'),
-    body('aiAssistantId').optional().isMongoId().withMessage('aiAssistantId must be string'),
     body('threadId').optional().isString().withMessage('threadId must be string'),
   ];
 
   return [
     accessTokenParser, loginRequiredStrictly, certifyAiService, validator, apiV3FormValidator,
     async(req: Req, res: ApiV3Response) => {
-      const { userMessage, markdown, threadId } = req.body;
+      const {
+        userMessage, markdown, threadId,
+      } = req.body;
 
       // Parameter check
       if (threadId == null) {
@@ -94,6 +96,20 @@ export const postMessageToEditHandlersFactory: PostMessageHandlersFactory = (cro
         return res.apiv3Err(new ErrorV3('GROWI AI is not enabled'), 501);
       }
 
+      const threadRelation = await ThreadRelationModel.findOne({ threadId });
+      if (threadRelation == null) {
+        return res.apiv3Err(new ErrorV3('ThreadRelation not found'), 404);
+      }
+
+      // Check if usable
+      if (threadRelation.aiAssistant != null) {
+        const aiAssistantId = getIdStringForRef(threadRelation.aiAssistant);
+        const isAiAssistantUsable = await openaiService.isAiAssistantUsable(aiAssistantId, req.user);
+        if (!isAiAssistantUsable) {
+          return res.apiv3Err(new ErrorV3('The specified AI assistant is not usable'), 400);
+        }
+      }
+
       // Initialize SSE helper and stream processor
       const sseHelper = new SseHelper(res);
       const streamProcessor = new LlmResponseStreamProcessor({