Yuki Takei 1 year ago
parent
commit
f6cc1d46a8

+ 18 - 3
apps/app/src/features/openai/client/components/AiAssistant/AiAssistantChatSidebar/AiAssistantChatSidebar.tsx

@@ -143,11 +143,21 @@ const AiAssistantChatSidebarSubstance: React.FC<AiAssistantChatSidebarSubstanceP
 
     // post message
     try {
-      const response = await fetch('/_api/v3/openai/message', {
+      const response = await fetch('/_api/v3/openai/edit', {
         method: 'POST',
         headers: { 'Content-Type': 'application/json' },
         body: JSON.stringify({
-          userMessage: data.input, threadId: currentThreadId_, summaryMode: data.summaryMode, aiAssistantId: aiAssistantData._id,
+          userMessage: data.input,
+          threadId: currentThreadId_,
+          summaryMode: data.summaryMode,
+          aiAssistantId: aiAssistantData._id,
+          markdown: `# :tada: Welcome to GROWI
+
+GROWI is an internal wiki & knowledge base tool for corporations and individuals.
+With GROWI, members can easily share and edit information in a company, university seminar, or circle.
+
+Casually writing down the information you know and editing it together can **reduce tacit knowledge within the team**.
+Let's increase the amount of information shared on a daily base!`,
         }),
       });
 
@@ -193,7 +203,12 @@ const AiAssistantChatSidebarSubstance: React.FC<AiAssistantChatSidebarSubstanceP
           const trimedLine = line.trim();
           if (trimedLine.startsWith('data:')) {
             const data = JSON.parse(line.replace('data: ', ''));
-            textValues.push(data.content[0].text.value);
+            if (data.content != null) {
+              textValues.push(data.content[0].text.value);
+            }
+            if (data.editorResponse != null) {
+              console.log('replace editor', { editorResponse: data.editorResponse });
+            }
           }
           else if (trimedLine.startsWith('error:')) {
             const error = JSON.parse(line.replace('error: ', ''));

+ 6 - 1
apps/app/src/features/openai/interfaces/message-error.ts

@@ -6,4 +6,9 @@ export const StreamErrorCode = {
   BUDGET_EXCEEDED: 'budget-exceeded',
 } as const;
 
-export type StreamErrorCode = typeof StreamErrorCode[keyof typeof StreamErrorCode];
+export type StreamErrorCode =
+  | 'PREREQUISITE_FAILED'
+  | 'UNSUPPORTED_MODEL'
+  | 'CONTEXT_LENGTH_EXCEEDED'
+  | 'INTERNAL_ERROR'
+  | 'INVALID_RESPONSE_FORMAT'; // JSON解析エラー用に追加

+ 172 - 10
apps/app/src/features/openai/server/routes/edit.ts

@@ -1,10 +1,16 @@
+import { Readable } from 'stream';
+
 import type { IUserHasId } from '@growi/core/dist/interfaces';
 import { ErrorV3 } from '@growi/core/dist/models';
 import type { Request, RequestHandler, Response } from 'express';
 import type { ValidationChain } from 'express-validator';
 import { body } from 'express-validator';
+import { zodResponseFormat } from 'openai/helpers/zod';
 import type { AssistantStream } from 'openai/lib/AssistantStream';
 import type { MessageDelta } from 'openai/resources/beta/threads/messages.mjs';
+import { parser } from 'stream-json';
+import { streamValues } from 'stream-json/streamers/StreamValues';
+import { z } from 'zod';
 
 import { getOrCreateEditorAssistant } from '~/features/openai/server/services/assistant';
 import type Crowi from '~/server/crowi';
@@ -23,6 +29,26 @@ import { certifyAiService } from './middlewares/certify-ai-service';
 
 const logger = loggerFactory('growi:routes:apiv3:openai:message');
 
+// 差分情報のスキーマ定義
+const EditorAssistantMessageSchema = z.object({
+  message: z.string().describe('A friendly message explaining what changes were made or suggested'),
+});
+
+const EditorAssistantDiffSchema = z.object({
+  start: z.number().int().describe('Zero-based index where replacement should begin'),
+  end: z.number().int().describe('Zero-based index where replacement should end'),
+  text: z.string().describe('The text that should replace the content between start and end positions'),
+});
+
+// 新しいレスポンス形式:
+const EditorAssistantResponseSchema = z.object({
+  contents: z.array(z.union([EditorAssistantMessageSchema, EditorAssistantDiffSchema])),
+}).describe('The response format for the editor assistant');
+
+// 型定義をZodスキーマから抽出
+type EditorAssistantMessage = z.infer<typeof EditorAssistantMessageSchema>;
+type EditorAssistantDiff = z.infer<typeof EditorAssistantDiffSchema>;
+type EditorAssistantResponse = z.infer<typeof EditorAssistantResponseSchema>;
 
 type ReqBody = {
   userMessage: string,
@@ -48,9 +74,9 @@ export const postMessageToEditHandlersFactory: PostMessageHandlersFactory = (cro
       .withMessage('userMessage must be set'),
     body('markdown')
       .isString()
-      .withMessage('userMessage must be string')
+      .withMessage('markdown must be string')
       .notEmpty()
-      .withMessage('userMessage must be set'),
+      .withMessage('markdown must be set'),
     body('aiAssistantId').optional().isMongoId().withMessage('aiAssistantId must be string'),
     body('threadId').optional().isString().withMessage('threadId must be string'),
   ];
@@ -58,7 +84,9 @@ export const postMessageToEditHandlersFactory: PostMessageHandlersFactory = (cro
   return [
     accessTokenParser, loginRequiredStrictly, certifyAiService, validator, apiV3FormValidator,
     async(req: Req, res: ApiV3Response) => {
-      const { aiAssistantId, threadId } = req.body;
+      const {
+        userMessage, markdown, aiAssistantId, threadId,
+      } = req.body;
 
       if (threadId == null) {
         return res.apiv3Err(new ErrorV3('threadId is not set', MessageErrorCode.THREAD_ID_IS_NOT_SET), 400);
@@ -76,22 +104,44 @@ export const postMessageToEditHandlersFactory: PostMessageHandlersFactory = (cro
 
         const thread = await openaiClient.beta.threads.retrieve(threadId);
 
+        // zodResponseFormatを使用して構造化された出力を得る
         stream = openaiClient.beta.threads.runs.stream(thread.id, {
           assistant_id: assistant.id,
           additional_messages: [
             {
               role: 'assistant',
-              content: '', // TODO: add a message to notify the user that the editing is started
+              content: `You are an Editor Assistant for GROWI, a markdown wiki system.
+              Your task is to help users edit their markdown content based on their requests.
+
+              RESPONSE FORMAT:
+              You must respond with a JSON object in the following format:
+              {
+                "contents": [
+                  { "message": "Your friendly message explaining what changes were made or suggested" },
+                  { "start": 0, "end": 10, "text": "New text 1" },
+                  { "message": "Additional explanation if needed" },
+                  { "start": 20, "end": 30, "text": "New text 2" },
+                  ...more items if needed
+                ]
+              }
+
+              The array should contain:
+              - Objects with a "message" key for explanatory text to the user
+              - Objects with "start", "end", and "text" keys for replacements
+
+              If no changes are needed, include only message objects with explanations.
+              Always provide messages in the same language as the user's request.`,
+            },
+            {
+              role: 'user',
+              content: `Current markdown content:\n\`\`\`markdown\n${markdown}\n\`\`\`\n\nUser request: ${userMessage}`,
             },
-            { role: 'user', content: req.body.userMessage },
           ],
+          response_format: zodResponseFormat(EditorAssistantResponseSchema, 'editor_assistant_response'),
         });
-
       }
       catch (err) {
         logger.error(err);
-
-        // TODO: improve error handling by https://redmine.weseek.co.jp/issues/155004
         return res.status(500).send(err.message);
       }
 
@@ -100,15 +150,63 @@ export const postMessageToEditHandlersFactory: PostMessageHandlersFactory = (cro
         'Cache-Control': 'no-cache, no-transform',
       });
 
+      // JSON解析用の変数
+      let rawBuffer = ''; // 完全な生の入力データを格納
+
+      // クライアントに送信するためのデータ
+      const messages: string[] = []; // 見つかったメッセージを格納
+      const replacements: EditorAssistantDiff[] = []; // 見つかった差分を格納
+
       const messageDeltaHandler = async(delta: MessageDelta) => {
         const content = delta.content?.[0];
 
-        // If annotation is found
+        // アノテーション処理は同様に行う
         if (content?.type === 'text' && content?.text?.annotations != null) {
           await replaceAnnotationWithPageLink(content, req.user.lang);
         }
 
-        res.write(`data: ${JSON.stringify(delta)}\n\n`);
+        if (content?.type === 'text' && content.text?.value) {
+          const chunk = content.text.value;
+          rawBuffer += chunk;
+
+          // バッファからストリームを作成
+          const bufferStream = Readable.from([rawBuffer]);
+
+          // JSONパーサーを設定
+          const jsonParser = bufferStream.pipe(parser()).pipe(streamValues());
+
+          jsonParser.on('data', ({ value, key }) => {
+            // contentsアレイ内の要素を検出
+            if (Array.isArray(value.contents)) {
+              // 完全なcontentsアレイが見つかった場合
+              value.contents.forEach((item) => {
+                if ('message' in item) {
+                  messages.push(item.message);
+                }
+                else if ('start' in item && 'end' in item && 'text' in item) {
+                  const validDiff = EditorAssistantDiffSchema.safeParse(item);
+                  if (validDiff.success) {
+                    replacements.push(validDiff.data);
+                  }
+                }
+              });
+
+              // 更新をクライアントに送信
+              res.write(`data: ${JSON.stringify({
+                editorResponse: {
+                  message: messages.length > 0 ? messages[messages.length - 1] : '',
+                  replacements,
+                },
+              })}\n\n`);
+            }
+          });
+
+          // 元のデルタも送信
+          res.write(`data: ${JSON.stringify(delta)}\n\n`);
+        }
+        else {
+          res.write(`data: ${JSON.stringify(delta)}\n\n`);
+        }
       };
 
       const sendError = (message: string, code?: StreamErrorCode) => {
@@ -125,14 +223,78 @@ export const postMessageToEditHandlersFactory: PostMessageHandlersFactory = (cro
           sendError(errorMessage, getStreamErrorCode(errorMessage));
         }
       });
+
       stream.on('messageDelta', messageDeltaHandler);
+
       stream.once('messageDone', () => {
+        // 処理完了時、最終的なレスポンスを送信
+        // 最後の確認としてJSONを解析してみる
+        try {
+          // バッファから完全なJSONを解析
+          const parsedJson = JSON.parse(rawBuffer);
+
+          if (parsedJson?.contents && Array.isArray(parsedJson.contents)) {
+            // 最終的なメッセージと差分を収集
+            const finalMessages: string[] = [];
+            const finalReplacements: EditorAssistantDiff[] = [];
+
+            for (const item of parsedJson.contents) {
+              if ('message' in item) {
+                finalMessages.push(item.message);
+              }
+              else if ('start' in item && 'end' in item && 'text' in item) {
+                const validDiff = EditorAssistantDiffSchema.safeParse(item);
+                if (validDiff.success) {
+                  finalReplacements.push(validDiff.data);
+                }
+              }
+            }
+
+            // 最終レスポンスをクライアントに送信(これまでに部分的に送信したものがあっても、完全なデータを再送)
+            res.write(`data: ${JSON.stringify({
+              editorResponse: {
+                message: finalMessages.length > 0 ? finalMessages[finalMessages.length - 1] : '',
+                replacements: finalReplacements,
+              },
+              isDone: true,
+            })}\n\n`);
+          }
+          else {
+            // 既に部分的に送信したデータがある場合はそれを最終データとして扱う
+            res.write(`data: ${JSON.stringify({
+              editorResponse: {
+                message: messages.length > 0 ? messages[messages.length - 1] : '',
+                replacements,
+              },
+              isDone: true,
+            })}\n\n`);
+          }
+        }
+        catch (e) {
+          logger.error('Failed to parse final JSON response', e);
+          // パース失敗時で既存のデータがある場合はそれを送信
+          if (messages.length > 0 || replacements.length > 0) {
+            res.write(`data: ${JSON.stringify({
+              editorResponse: {
+                message: messages.length > 0 ? messages[messages.length - 1] : '',
+                replacements,
+              },
+              isDone: true,
+            })}\n\n`);
+          }
+          else {
+            sendError('Failed to parse assistant response as JSON', 'INVALID_RESPONSE_FORMAT');
+          }
+        }
+
         stream.off('messageDelta', messageDeltaHandler);
         res.end();
       });
+
       stream.once('error', (err) => {
         logger.error(err);
         stream.off('messageDelta', messageDeltaHandler);
+        sendError('An error occurred while processing your request');
         res.end();
       });
     },