Просмотр исходного кода

assistantId is retrieved from thread-relation

Shun Miyazawa 1 год назад
Родитель
Commit
bf1551ff91

+ 1 - 1
apps/app/src/features/openai/client/components/AiAssistant/AiAssistantSidebar/AiAssistantSidebar.tsx

@@ -190,7 +190,7 @@ const AiAssistantSidebarSubstance: React.FC<AiAssistantSidebarSubstanceProps> =
 
 
       const response = await (async() => {
       const response = await (async() => {
         if (isEditorAssistant) {
         if (isEditorAssistant) {
-          return postMessageForEditorAssistant(currentThreadId_, data.input, '# markdown', selectedAiAssistant?._id);
+          return postMessageForEditorAssistant(currentThreadId_, data.input, '# markdown');
         }
         }
         if (aiAssistantData?._id != null) {
         if (aiAssistantData?._id != null) {
           return postMessageForKnowledgeAssistant(aiAssistantData._id, currentThreadId_, data.input, data.summaryMode);
           return postMessageForKnowledgeAssistant(aiAssistantData._id, currentThreadId_, data.input, data.summaryMode);

+ 2 - 3
apps/app/src/features/openai/client/services/editor-assistant.ts

@@ -11,7 +11,7 @@ import {
 import { handleIfSuccessfullyParsed } from '~/features/openai/utils/handle-if-successfully-parsed';
 import { handleIfSuccessfullyParsed } from '~/features/openai/utils/handle-if-successfully-parsed';
 
 
 interface PostMessage {
 interface PostMessage {
-  (threadId: string, userMessage: string, markdown: string, aiAssistantId?: string): Promise<Response>;
+  (threadId: string, userMessage: string, markdown: string): Promise<Response>;
 }
 }
 interface ProcessMessage {
 interface ProcessMessage {
   (data: unknown, handler: {
   (data: unknown, handler: {
@@ -22,12 +22,11 @@ interface ProcessMessage {
 }
 }
 
 
 export const useEditorAssistant = (): { postMessage: PostMessage, processMessage: ProcessMessage } => {
 export const useEditorAssistant = (): { postMessage: PostMessage, processMessage: ProcessMessage } => {
-  const postMessage: PostMessage = useCallback(async(threadId, userMessage, markdown, aiAssistantId) => {
+  const postMessage: PostMessage = useCallback(async(threadId, userMessage, markdown) => {
     const response = await fetch('/_api/v3/openai/edit', {
     const response = await fetch('/_api/v3/openai/edit', {
       method: 'POST',
       method: 'POST',
       headers: { 'Content-Type': 'application/json' },
       headers: { 'Content-Type': 'application/json' },
       body: JSON.stringify({
       body: JSON.stringify({
-        aiAssistantId,
         threadId,
         threadId,
         userMessage,
         userMessage,
         markdown,
         markdown,

+ 10 - 4
apps/app/src/features/openai/server/routes/edit/index.ts

@@ -1,3 +1,4 @@
+import { getIdStringForRef } from '@growi/core';
 import type { IUserHasId } from '@growi/core/dist/interfaces';
 import type { IUserHasId } from '@growi/core/dist/interfaces';
 import { ErrorV3 } from '@growi/core/dist/models';
 import { ErrorV3 } from '@growi/core/dist/models';
 import type { Request, RequestHandler, Response } from 'express';
 import type { Request, RequestHandler, Response } from 'express';
@@ -17,6 +18,7 @@ import loggerFactory from '~/utils/logger';
 import { LlmEditorAssistantDiffSchema, LlmEditorAssistantMessageSchema } from '../../../interfaces/editor-assistant/llm-response-schemas';
 import { LlmEditorAssistantDiffSchema, LlmEditorAssistantMessageSchema } from '../../../interfaces/editor-assistant/llm-response-schemas';
 import type { SseDetectedDiff, SseFinalized, SseMessage } from '../../../interfaces/editor-assistant/sse-schemas';
 import type { SseDetectedDiff, SseFinalized, SseMessage } from '../../../interfaces/editor-assistant/sse-schemas';
 import { MessageErrorCode } from '../../../interfaces/message-error';
 import { MessageErrorCode } from '../../../interfaces/message-error';
+import ThreadRelationModel from '../../models/thread-relation';
 import { getOrCreateEditorAssistant } from '../../services/assistant';
 import { getOrCreateEditorAssistant } from '../../services/assistant';
 import { openaiClient } from '../../services/client';
 import { openaiClient } from '../../services/client';
 import { LlmResponseStreamProcessor } from '../../services/editor-assistant';
 import { LlmResponseStreamProcessor } from '../../services/editor-assistant';
@@ -41,7 +43,6 @@ const LlmEditorAssistantResponseSchema = z.object({
 type ReqBody = {
 type ReqBody = {
   userMessage: string,
   userMessage: string,
   markdown: string,
   markdown: string,
-  aiAssistantId?: string,
   threadId?: string,
   threadId?: string,
 }
 }
 
 
@@ -74,7 +75,6 @@ export const postMessageToEditHandlersFactory: PostMessageHandlersFactory = (cro
       .withMessage('markdown must be string')
       .withMessage('markdown must be string')
       .notEmpty()
       .notEmpty()
       .withMessage('markdown must be set'),
       .withMessage('markdown must be set'),
-    body('aiAssistantId').optional().isMongoId().withMessage('aiAssistantId must be string'),
     body('threadId').optional().isString().withMessage('threadId must be string'),
     body('threadId').optional().isString().withMessage('threadId must be string'),
   ];
   ];
 
 
@@ -82,7 +82,7 @@ export const postMessageToEditHandlersFactory: PostMessageHandlersFactory = (cro
     accessTokenParser, loginRequiredStrictly, certifyAiService, validator, apiV3FormValidator,
     accessTokenParser, loginRequiredStrictly, certifyAiService, validator, apiV3FormValidator,
     async(req: Req, res: ApiV3Response) => {
     async(req: Req, res: ApiV3Response) => {
       const {
       const {
-        userMessage, markdown, threadId, aiAssistantId,
+        userMessage, markdown, threadId,
       } = req.body;
       } = req.body;
 
 
       // Parameter check
       // Parameter check
@@ -96,8 +96,14 @@ export const postMessageToEditHandlersFactory: PostMessageHandlersFactory = (cro
         return res.apiv3Err(new ErrorV3('GROWI AI is not enabled'), 501);
         return res.apiv3Err(new ErrorV3('GROWI AI is not enabled'), 501);
       }
       }
 
 
+      const threadRelation = await ThreadRelationModel.findOne({ threadId });
+      if (threadRelation == null) {
+        return res.apiv3Err(new ErrorV3('ThreadRelation not found'), 404);
+      }
+
       // Check if usable
       // Check if usable
-      if (aiAssistantId != null) {
+      if (threadRelation.aiAssistant != null) {
+        const aiAssistantId = getIdStringForRef(threadRelation.aiAssistant);
         const isAiAssistantUsable = await openaiService.isAiAssistantUsable(aiAssistantId, req.user);
         const isAiAssistantUsable = await openaiService.isAiAssistantUsable(aiAssistantId, req.user);
         if (!isAiAssistantUsable) {
         if (!isAiAssistantUsable) {
           return res.apiv3Err(new ErrorV3('The specified AI assistant is not usable'), 400);
           return res.apiv3Err(new ErrorV3('The specified AI assistant is not usable'), 400);