Yuki Takei 1 an în urmă
părinte
comite
a98b1b6aea
1 a modificat fișierele cu 174 adăugiri și 156 ștergeri
  1. 174 156
      apps/app/src/features/openai/server/routes/edit.ts

+ 174 - 156
apps/app/src/features/openai/server/routes/edit.ts

@@ -6,7 +6,6 @@ import type { Request, RequestHandler, Response } from 'express';
 import type { ValidationChain } from 'express-validator';
 import type { ValidationChain } from 'express-validator';
 import { body } from 'express-validator';
 import { body } from 'express-validator';
 import { zodResponseFormat } from 'openai/helpers/zod';
 import { zodResponseFormat } from 'openai/helpers/zod';
-import type { AssistantStream } from 'openai/lib/AssistantStream';
 import type { MessageDelta } from 'openai/resources/beta/threads/messages.mjs';
 import type { MessageDelta } from 'openai/resources/beta/threads/messages.mjs';
 import { parser } from 'stream-json';
 import { parser } from 'stream-json';
 import { streamValues } from 'stream-json/streamers/StreamValues';
 import { streamValues } from 'stream-json/streamers/StreamValues';
@@ -29,7 +28,7 @@ import { certifyAiService } from './middlewares/certify-ai-service';
 
 
 const logger = loggerFactory('growi:routes:apiv3:openai:message');
 const logger = loggerFactory('growi:routes:apiv3:openai:message');
 
 
-// 差分情報のスキーマ定義
+// スキーマ定義
 const EditorAssistantMessageSchema = z.object({
 const EditorAssistantMessageSchema = z.object({
   message: z.string().describe('A friendly message explaining what changes were made or suggested'),
   message: z.string().describe('A friendly message explaining what changes were made or suggested'),
 });
 });
@@ -40,15 +39,13 @@ const EditorAssistantDiffSchema = z.object({
   text: z.string().describe('The text that should replace the content between start and end positions'),
   text: z.string().describe('The text that should replace the content between start and end positions'),
 });
 });
 
 
-// 新しいレスポンス形式:
 const EditorAssistantResponseSchema = z.object({
 const EditorAssistantResponseSchema = z.object({
   contents: z.array(z.union([EditorAssistantMessageSchema, EditorAssistantDiffSchema])),
   contents: z.array(z.union([EditorAssistantMessageSchema, EditorAssistantDiffSchema])),
 }).describe('The response format for the editor assistant');
 }).describe('The response format for the editor assistant');
 
 
-// 型定義をZodスキーマから抽出
+// 型定義
 type EditorAssistantMessage = z.infer<typeof EditorAssistantMessageSchema>;
 type EditorAssistantMessage = z.infer<typeof EditorAssistantMessageSchema>;
 type EditorAssistantDiff = z.infer<typeof EditorAssistantDiffSchema>;
 type EditorAssistantDiff = z.infer<typeof EditorAssistantDiffSchema>;
-type EditorAssistantResponse = z.infer<typeof EditorAssistantResponseSchema>;
 
 
 type ReqBody = {
 type ReqBody = {
   userMessage: string,
   userMessage: string,
@@ -56,13 +53,73 @@ type ReqBody = {
   aiAssistantId?: string,
   aiAssistantId?: string,
   threadId?: string,
   threadId?: string,
 }
 }
-
 type Req = Request<undefined, Response, ReqBody> & {
 type Req = Request<undefined, Response, ReqBody> & {
   user: IUserHasId,
   user: IUserHasId,
 }
 }
-
 type PostMessageHandlersFactory = (crowi: Crowi) => RequestHandler[];
 type PostMessageHandlersFactory = (crowi: Crowi) => RequestHandler[];
 
 
+/**
+ * 型ガード: メッセージ型かどうかを判定する
+ */
+const isMessageItem = (item: unknown): item is EditorAssistantMessage => {
+  return typeof item === 'object' && item !== null && 'message' in item;
+};
+
+/**
+ * 型ガード: 差分型かどうかを判定する
+ */
+const isDiffItem = (item: unknown): item is EditorAssistantDiff => {
+  return typeof item === 'object' && item !== null
+    && 'start' in item && 'end' in item && 'text' in item;
+};
+
+/**
+ * コンテンツからメッセージと差分を抽出する
+ */
+const extractContentItems = (contents: unknown[]) => {
+  const messages: string[] = [];
+  const replacements: EditorAssistantDiff[] = [];
+
+  contents.forEach((item) => {
+    if (isMessageItem(item)) {
+      messages.push(item.message);
+    }
+    else if (isDiffItem(item)) {
+      const validDiff = EditorAssistantDiffSchema.safeParse(item);
+      if (validDiff.success) {
+        replacements.push(validDiff.data);
+      }
+    }
+  });
+
+  return { messages, replacements };
+};
+
+/**
+ * エディターアシスタントのレスポンスデータを作成する
+ */
+const createEditorResponse = (messages: string[], replacements: EditorAssistantDiff[], isDone = false) => ({
+  editorResponse: {
+    message: messages.length > 0 ? messages[messages.length - 1] : '',
+    replacements,
+  },
+  ...(isDone ? { isDone: true } : {}),
+});
+
+/**
+ * SSEフォーマットでデータを送信する
+ */
+const writeSSEData = (res: Response, data: unknown) => {
+  res.write(`data: ${JSON.stringify(data)}\n\n`);
+};
+
+/**
+ * SSEフォーマットでエラーを送信する
+ */
+const writeSSEError = (res: Response, message: string, code?: StreamErrorCode) => {
+  res.write(`error: ${JSON.stringify({ code, message })}\n\n`);
+};
+
 export const postMessageToEditHandlersFactory: PostMessageHandlersFactory = (crowi) => {
 export const postMessageToEditHandlersFactory: PostMessageHandlersFactory = (crowi) => {
   const loginRequiredStrictly = require('~/server/middlewares/login-required')(crowi);
   const loginRequiredStrictly = require('~/server/middlewares/login-required')(crowi);
 
 
@@ -84,28 +141,37 @@ export const postMessageToEditHandlersFactory: PostMessageHandlersFactory = (cro
   return [
   return [
     accessTokenParser, loginRequiredStrictly, certifyAiService, validator, apiV3FormValidator,
     accessTokenParser, loginRequiredStrictly, certifyAiService, validator, apiV3FormValidator,
     async(req: Req, res: ApiV3Response) => {
     async(req: Req, res: ApiV3Response) => {
-      const {
-        userMessage, markdown, aiAssistantId, threadId,
-      } = req.body;
+      const { userMessage, markdown, threadId } = req.body;
 
 
+      // パラメータチェック
       if (threadId == null) {
       if (threadId == null) {
         return res.apiv3Err(new ErrorV3('threadId is not set', MessageErrorCode.THREAD_ID_IS_NOT_SET), 400);
         return res.apiv3Err(new ErrorV3('threadId is not set', MessageErrorCode.THREAD_ID_IS_NOT_SET), 400);
       }
       }
 
 
+      // サービスチェック
       const openaiService = getOpenaiService();
       const openaiService = getOpenaiService();
       if (openaiService == null) {
       if (openaiService == null) {
         return res.apiv3Err(new ErrorV3('GROWI AI is not enabled'), 501);
         return res.apiv3Err(new ErrorV3('GROWI AI is not enabled'), 501);
       }
       }
 
 
-      let stream: AssistantStream;
+      // レスポンスデータ格納用
+      const messages: string[] = [];
+      const replacements: EditorAssistantDiff[] = [];
+      let rawBuffer = '';
 
 
       try {
       try {
-        const assistant = await getOrCreateEditorAssistant();
+        // レスポンスヘッダー設定
+        res.writeHead(200, {
+          'Content-Type': 'text/event-stream;charset=utf-8',
+          'Cache-Control': 'no-cache, no-transform',
+        });
 
 
+        // アシスタント取得とスレッド処理
+        const assistant = await getOrCreateEditorAssistant();
         const thread = await openaiClient.beta.threads.retrieve(threadId);
         const thread = await openaiClient.beta.threads.retrieve(threadId);
 
 
-        // zodResponseFormatを使用して構造化された出力を得る
-        stream = openaiClient.beta.threads.runs.stream(thread.id, {
+        // ストリーム作成
+        const stream = openaiClient.beta.threads.runs.stream(thread.id, {
           assistant_id: assistant.id,
           assistant_id: assistant.id,
           additional_messages: [
           additional_messages: [
             {
             {
@@ -139,164 +205,116 @@ export const postMessageToEditHandlersFactory: PostMessageHandlersFactory = (cro
           ],
           ],
           response_format: zodResponseFormat(EditorAssistantResponseSchema, 'editor_assistant_response'),
           response_format: zodResponseFormat(EditorAssistantResponseSchema, 'editor_assistant_response'),
         });
         });
-      }
-      catch (err) {
-        logger.error(err);
-        return res.status(500).send(err.message);
-      }
-
-      res.writeHead(200, {
-        'Content-Type': 'text/event-stream;charset=utf-8',
-        'Cache-Control': 'no-cache, no-transform',
-      });
-
-      // JSON解析用の変数
-      let rawBuffer = ''; // 完全な生の入力データを格納
-
-      // クライアントに送信するためのデータ
-      const messages: string[] = []; // 見つかったメッセージを格納
-      const replacements: EditorAssistantDiff[] = []; // 見つかった差分を格納
-
-      const messageDeltaHandler = async(delta: MessageDelta) => {
-        const content = delta.content?.[0];
 
 
-        // アノテーション処理は同様に行う
-        if (content?.type === 'text' && content?.text?.annotations != null) {
-          await replaceAnnotationWithPageLink(content, req.user.lang);
-        }
+        // メッセージデルタハンドラ
+        const messageDeltaHandler = async(delta: MessageDelta) => {
+          const content = delta.content?.[0];
 
 
-        if (content?.type === 'text' && content.text?.value) {
-          const chunk = content.text.value;
-          rawBuffer += chunk;
-
-          // バッファからストリームを作成
-          const bufferStream = Readable.from([rawBuffer]);
-
-          // JSONパーサーを設定
-          const jsonParser = bufferStream.pipe(parser()).pipe(streamValues());
+          // アノテーション処理
+          if (content?.type === 'text' && content?.text?.annotations != null) {
+            await replaceAnnotationWithPageLink(content, req.user.lang);
+          }
 
 
-          jsonParser.on('data', ({ value, key }) => {
-            // contentsアレイ内の要素を検出
-            if (Array.isArray(value.contents)) {
-              // 完全なcontentsアレイが見つかった場合
-              value.contents.forEach((item) => {
-                if ('message' in item) {
-                  messages.push(item.message);
-                }
-                else if ('start' in item && 'end' in item && 'text' in item) {
-                  const validDiff = EditorAssistantDiffSchema.safeParse(item);
-                  if (validDiff.success) {
-                    replacements.push(validDiff.data);
-                  }
+          // テキスト処理
+          if (content?.type === 'text' && content.text?.value) {
+            const chunk = content.text.value;
+            rawBuffer += chunk;
+
+            // JSONパース処理
+            try {
+              // ストリームから解析
+              const bufferStream = Readable.from([rawBuffer]);
+              const jsonParser = bufferStream.pipe(parser()).pipe(streamValues());
+
+              jsonParser.on('data', ({ value }) => {
+                // contentsアレイの処理
+                if (value?.contents && Array.isArray(value.contents)) {
+                  // メッセージと差分情報の抽出
+                  const extracted = extractContentItems(value.contents);
+                  messages.push(...extracted.messages);
+                  replacements.push(...extracted.replacements);
+
+                  // データ送信
+                  writeSSEData(res, createEditorResponse(messages, replacements));
                 }
                 }
               });
               });
-
-              // 更新をクライアントに送信
-              res.write(`data: ${JSON.stringify({
-                editorResponse: {
-                  message: messages.length > 0 ? messages[messages.length - 1] : '',
-                  replacements,
-                },
-              })}\n\n`);
             }
             }
-          });
-
-          // 元のデルタも送信
-          res.write(`data: ${JSON.stringify(delta)}\n\n`);
-        }
-        else {
-          res.write(`data: ${JSON.stringify(delta)}\n\n`);
-        }
-      };
-
-      const sendError = (message: string, code?: StreamErrorCode) => {
-        res.write(`error: ${JSON.stringify({ code, message })}\n\n`);
-      };
-
-      stream.on('event', (delta) => {
-        if (delta.event === 'thread.run.failed') {
-          const errorMessage = delta.data.last_error?.message;
-          if (errorMessage == null) {
-            return;
-          }
-          logger.error(errorMessage);
-          sendError(errorMessage, getStreamErrorCode(errorMessage));
-        }
-      });
-
-      stream.on('messageDelta', messageDeltaHandler);
-
-      stream.once('messageDone', () => {
-        // 処理完了時、最終的なレスポンスを送信
-        // 最後の確認としてJSONを解析してみる
-        try {
-          // バッファから完全なJSONを解析
-          const parsedJson = JSON.parse(rawBuffer);
-
-          if (parsedJson?.contents && Array.isArray(parsedJson.contents)) {
-            // 最終的なメッセージと差分を収集
-            const finalMessages: string[] = [];
-            const finalReplacements: EditorAssistantDiff[] = [];
-
-            for (const item of parsedJson.contents) {
-              if ('message' in item) {
-                finalMessages.push(item.message);
-              }
-              else if ('start' in item && 'end' in item && 'text' in item) {
-                const validDiff = EditorAssistantDiffSchema.safeParse(item);
-                if (validDiff.success) {
-                  finalReplacements.push(validDiff.data);
-                }
-              }
+            catch (e) {
+              // JSON解析中のエラーは無視(おそらく不完全なJSONデータ)
             }
             }
 
 
-            // 最終レスポンスをクライアントに送信(これまでに部分的に送信したものがあっても、完全なデータを再送)
-            res.write(`data: ${JSON.stringify({
-              editorResponse: {
-                message: finalMessages.length > 0 ? finalMessages[finalMessages.length - 1] : '',
-                replacements: finalReplacements,
-              },
-              isDone: true,
-            })}\n\n`);
+            // 元のデルタも送信
+            writeSSEData(res, delta);
           }
           }
           else {
           else {
-            // 既に部分的に送信したデータがある場合はそれを最終データとして扱う
-            res.write(`data: ${JSON.stringify({
-              editorResponse: {
-                message: messages.length > 0 ? messages[messages.length - 1] : '',
-                replacements,
-              },
-              isDone: true,
-            })}\n\n`);
+            writeSSEData(res, delta);
           }
           }
-        }
-        catch (e) {
-          logger.error('Failed to parse final JSON response', e);
-          // パース失敗時で既存のデータがある場合はそれを送信
-          if (messages.length > 0 || replacements.length > 0) {
-            res.write(`data: ${JSON.stringify({
-              editorResponse: {
-                message: messages.length > 0 ? messages[messages.length - 1] : '',
-                replacements,
-              },
-              isDone: true,
-            })}\n\n`);
+        };
+
+        // イベントハンドラ登録
+        stream.on('messageDelta', messageDeltaHandler);
+
+        // Runエラーハンドラ
+        stream.on('event', (delta) => {
+          if (delta.event === 'thread.run.failed') {
+            const errorMessage = delta.data.last_error?.message;
+            if (errorMessage == null) return;
+
+            logger.error(errorMessage);
+            writeSSEError(res, errorMessage, getStreamErrorCode(errorMessage));
           }
           }
-          else {
-            sendError('Failed to parse assistant response as JSON', 'INVALID_RESPONSE_FORMAT');
+        });
+
+        // 完了ハンドラ
+        stream.once('messageDone', () => {
+          // 最終確認として完全なJSONをパース
+          try {
+            const parsedJson = JSON.parse(rawBuffer);
+
+            if (parsedJson?.contents && Array.isArray(parsedJson.contents)) {
+              // 最終的なメッセージと差分を収集
+              const extracted = extractContentItems(parsedJson.contents);
+
+              // 最終データ送信
+              writeSSEData(res, createEditorResponse(
+                extracted.messages.length > 0 ? extracted.messages : messages,
+                extracted.replacements.length > 0 ? extracted.replacements : replacements,
+                true,
+              ));
+            }
+            else if (messages.length > 0 || replacements.length > 0) {
+              // パース結果が期待形式でなくても、部分的なデータがあれば送信
+              writeSSEData(res, createEditorResponse(messages, replacements, true));
+            }
           }
           }
-        }
+          catch (e) {
+            logger.error('Failed to parse final JSON response', e);
 
 
-        stream.off('messageDelta', messageDeltaHandler);
-        res.end();
-      });
+            if (messages.length > 0 || replacements.length > 0) {
+              // パース失敗でも、既存データがあれば送信
+              writeSSEData(res, createEditorResponse(messages, replacements, true));
+            }
+            else {
+              writeSSEError(res, 'Failed to parse assistant response as JSON', 'INVALID_RESPONSE_FORMAT');
+            }
+          }
 
 
-      stream.once('error', (err) => {
+          stream.off('messageDelta', messageDeltaHandler);
+          res.end();
+        });
+
+        // エラーハンドラ
+        stream.once('error', (err) => {
+          logger.error(err);
+          stream.off('messageDelta', messageDeltaHandler);
+          writeSSEError(res, 'An error occurred while processing your request');
+          res.end();
+        });
+      }
+      catch (err) {
         logger.error(err);
         logger.error(err);
-        stream.off('messageDelta', messageDeltaHandler);
-        sendError('An error occurred while processing your request');
-        res.end();
-      });
+        return res.status(500).send(err.message);
+      }
     },
     },
   ];
   ];
 };
 };