|
@@ -1,10 +1,16 @@
|
|
|
|
|
+import { Readable } from 'stream';
|
|
|
|
|
+
|
|
|
import type { IUserHasId } from '@growi/core/dist/interfaces';
|
|
import type { IUserHasId } from '@growi/core/dist/interfaces';
|
|
|
import { ErrorV3 } from '@growi/core/dist/models';
|
|
import { ErrorV3 } from '@growi/core/dist/models';
|
|
|
import type { Request, RequestHandler, Response } from 'express';
|
|
import type { Request, RequestHandler, Response } from 'express';
|
|
|
import type { ValidationChain } from 'express-validator';
|
|
import type { ValidationChain } from 'express-validator';
|
|
|
import { body } from 'express-validator';
|
|
import { body } from 'express-validator';
|
|
|
|
|
+import { zodResponseFormat } from 'openai/helpers/zod';
|
|
|
import type { AssistantStream } from 'openai/lib/AssistantStream';
|
|
import type { AssistantStream } from 'openai/lib/AssistantStream';
|
|
|
import type { MessageDelta } from 'openai/resources/beta/threads/messages.mjs';
|
|
import type { MessageDelta } from 'openai/resources/beta/threads/messages.mjs';
|
|
|
|
|
+import { parser } from 'stream-json';
|
|
|
|
|
+import { streamValues } from 'stream-json/streamers/StreamValues';
|
|
|
|
|
+import { z } from 'zod';
|
|
|
|
|
|
|
|
import { getOrCreateEditorAssistant } from '~/features/openai/server/services/assistant';
|
|
import { getOrCreateEditorAssistant } from '~/features/openai/server/services/assistant';
|
|
|
import type Crowi from '~/server/crowi';
|
|
import type Crowi from '~/server/crowi';
|
|
@@ -23,6 +29,26 @@ import { certifyAiService } from './middlewares/certify-ai-service';
|
|
|
|
|
|
|
|
const logger = loggerFactory('growi:routes:apiv3:openai:message');
|
|
const logger = loggerFactory('growi:routes:apiv3:openai:message');
|
|
|
|
|
|
|
|
|
|
+// 差分情報のスキーマ定義
|
|
|
|
|
+const EditorAssistantMessageSchema = z.object({
|
|
|
|
|
+ message: z.string().describe('A friendly message explaining what changes were made or suggested'),
|
|
|
|
|
+});
|
|
|
|
|
+
|
|
|
|
|
+const EditorAssistantDiffSchema = z.object({
|
|
|
|
|
+ start: z.number().int().describe('Zero-based index where replacement should begin'),
|
|
|
|
|
+ end: z.number().int().describe('Zero-based index where replacement should end'),
|
|
|
|
|
+ text: z.string().describe('The text that should replace the content between start and end positions'),
|
|
|
|
|
+});
|
|
|
|
|
+
|
|
|
|
|
+// 新しいレスポンス形式:
|
|
|
|
|
+const EditorAssistantResponseSchema = z.object({
|
|
|
|
|
+ contents: z.array(z.union([EditorAssistantMessageSchema, EditorAssistantDiffSchema])),
|
|
|
|
|
+}).describe('The response format for the editor assistant');
|
|
|
|
|
+
|
|
|
|
|
+// 型定義をZodスキーマから抽出
|
|
|
|
|
+type EditorAssistantMessage = z.infer<typeof EditorAssistantMessageSchema>;
|
|
|
|
|
+type EditorAssistantDiff = z.infer<typeof EditorAssistantDiffSchema>;
|
|
|
|
|
+type EditorAssistantResponse = z.infer<typeof EditorAssistantResponseSchema>;
|
|
|
|
|
|
|
|
type ReqBody = {
|
|
type ReqBody = {
|
|
|
userMessage: string,
|
|
userMessage: string,
|
|
@@ -48,9 +74,9 @@ export const postMessageToEditHandlersFactory: PostMessageHandlersFactory = (cro
|
|
|
.withMessage('userMessage must be set'),
|
|
.withMessage('userMessage must be set'),
|
|
|
body('markdown')
|
|
body('markdown')
|
|
|
.isString()
|
|
.isString()
|
|
|
- .withMessage('userMessage must be string')
|
|
|
|
|
|
|
+ .withMessage('markdown must be string')
|
|
|
.notEmpty()
|
|
.notEmpty()
|
|
|
- .withMessage('userMessage must be set'),
|
|
|
|
|
|
|
+ .withMessage('markdown must be set'),
|
|
|
body('aiAssistantId').optional().isMongoId().withMessage('aiAssistantId must be string'),
|
|
body('aiAssistantId').optional().isMongoId().withMessage('aiAssistantId must be string'),
|
|
|
body('threadId').optional().isString().withMessage('threadId must be string'),
|
|
body('threadId').optional().isString().withMessage('threadId must be string'),
|
|
|
];
|
|
];
|
|
@@ -58,7 +84,9 @@ export const postMessageToEditHandlersFactory: PostMessageHandlersFactory = (cro
|
|
|
return [
|
|
return [
|
|
|
accessTokenParser, loginRequiredStrictly, certifyAiService, validator, apiV3FormValidator,
|
|
accessTokenParser, loginRequiredStrictly, certifyAiService, validator, apiV3FormValidator,
|
|
|
async(req: Req, res: ApiV3Response) => {
|
|
async(req: Req, res: ApiV3Response) => {
|
|
|
- const { aiAssistantId, threadId } = req.body;
|
|
|
|
|
|
|
+ const {
|
|
|
|
|
+ userMessage, markdown, aiAssistantId, threadId,
|
|
|
|
|
+ } = req.body;
|
|
|
|
|
|
|
|
if (threadId == null) {
|
|
if (threadId == null) {
|
|
|
return res.apiv3Err(new ErrorV3('threadId is not set', MessageErrorCode.THREAD_ID_IS_NOT_SET), 400);
|
|
return res.apiv3Err(new ErrorV3('threadId is not set', MessageErrorCode.THREAD_ID_IS_NOT_SET), 400);
|
|
@@ -76,22 +104,44 @@ export const postMessageToEditHandlersFactory: PostMessageHandlersFactory = (cro
|
|
|
|
|
|
|
|
const thread = await openaiClient.beta.threads.retrieve(threadId);
|
|
const thread = await openaiClient.beta.threads.retrieve(threadId);
|
|
|
|
|
|
|
|
|
|
+ // zodResponseFormatを使用して構造化された出力を得る
|
|
|
stream = openaiClient.beta.threads.runs.stream(thread.id, {
|
|
stream = openaiClient.beta.threads.runs.stream(thread.id, {
|
|
|
assistant_id: assistant.id,
|
|
assistant_id: assistant.id,
|
|
|
additional_messages: [
|
|
additional_messages: [
|
|
|
{
|
|
{
|
|
|
role: 'assistant',
|
|
role: 'assistant',
|
|
|
- content: '', // TODO: add a message to notify the user that the editing is started
|
|
|
|
|
|
|
+ content: `You are an Editor Assistant for GROWI, a markdown wiki system.
|
|
|
|
|
+ Your task is to help users edit their markdown content based on their requests.
|
|
|
|
|
+
|
|
|
|
|
+ RESPONSE FORMAT:
|
|
|
|
|
+ You must respond with a JSON object in the following format:
|
|
|
|
|
+ {
|
|
|
|
|
+ "contents": [
|
|
|
|
|
+ { "message": "Your friendly message explaining what changes were made or suggested" },
|
|
|
|
|
+ { "start": 0, "end": 10, "text": "New text 1" },
|
|
|
|
|
+ { "message": "Additional explanation if needed" },
|
|
|
|
|
+ { "start": 20, "end": 30, "text": "New text 2" },
|
|
|
|
|
+ ...more items if needed
|
|
|
|
|
+ ]
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ The array should contain:
|
|
|
|
|
+ - Objects with a "message" key for explanatory text to the user
|
|
|
|
|
+ - Objects with "start", "end", and "text" keys for replacements
|
|
|
|
|
+
|
|
|
|
|
+ If no changes are needed, include only message objects with explanations.
|
|
|
|
|
+ Always provide messages in the same language as the user's request.`,
|
|
|
|
|
+ },
|
|
|
|
|
+ {
|
|
|
|
|
+ role: 'user',
|
|
|
|
|
+ content: `Current markdown content:\n\`\`\`markdown\n${markdown}\n\`\`\`\n\nUser request: ${userMessage}`,
|
|
|
},
|
|
},
|
|
|
- { role: 'user', content: req.body.userMessage },
|
|
|
|
|
],
|
|
],
|
|
|
|
|
+ response_format: zodResponseFormat(EditorAssistantResponseSchema, 'editor_assistant_response'),
|
|
|
});
|
|
});
|
|
|
-
|
|
|
|
|
}
|
|
}
|
|
|
catch (err) {
|
|
catch (err) {
|
|
|
logger.error(err);
|
|
logger.error(err);
|
|
|
-
|
|
|
|
|
- // TODO: improve error handling by https://redmine.weseek.co.jp/issues/155004
|
|
|
|
|
return res.status(500).send(err.message);
|
|
return res.status(500).send(err.message);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -100,15 +150,63 @@ export const postMessageToEditHandlersFactory: PostMessageHandlersFactory = (cro
|
|
|
'Cache-Control': 'no-cache, no-transform',
|
|
'Cache-Control': 'no-cache, no-transform',
|
|
|
});
|
|
});
|
|
|
|
|
|
|
|
|
|
+ // JSON解析用の変数
|
|
|
|
|
+ let rawBuffer = ''; // 完全な生の入力データを格納
|
|
|
|
|
+
|
|
|
|
|
+ // クライアントに送信するためのデータ
|
|
|
|
|
+ const messages: string[] = []; // 見つかったメッセージを格納
|
|
|
|
|
+ const replacements: EditorAssistantDiff[] = []; // 見つかった差分を格納
|
|
|
|
|
+
|
|
|
const messageDeltaHandler = async(delta: MessageDelta) => {
|
|
const messageDeltaHandler = async(delta: MessageDelta) => {
|
|
|
const content = delta.content?.[0];
|
|
const content = delta.content?.[0];
|
|
|
|
|
|
|
|
- // If annotation is found
|
|
|
|
|
|
|
+ // アノテーション処理は同様に行う
|
|
|
if (content?.type === 'text' && content?.text?.annotations != null) {
|
|
if (content?.type === 'text' && content?.text?.annotations != null) {
|
|
|
await replaceAnnotationWithPageLink(content, req.user.lang);
|
|
await replaceAnnotationWithPageLink(content, req.user.lang);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- res.write(`data: ${JSON.stringify(delta)}\n\n`);
|
|
|
|
|
|
|
+ if (content?.type === 'text' && content.text?.value) {
|
|
|
|
|
+ const chunk = content.text.value;
|
|
|
|
|
+ rawBuffer += chunk;
|
|
|
|
|
+
|
|
|
|
|
+ // バッファからストリームを作成
|
|
|
|
|
+ const bufferStream = Readable.from([rawBuffer]);
|
|
|
|
|
+
|
|
|
|
|
+ // JSONパーサーを設定
|
|
|
|
|
+ const jsonParser = bufferStream.pipe(parser()).pipe(streamValues());
|
|
|
|
|
+
|
|
|
|
|
+ jsonParser.on('data', ({ value, key }) => {
|
|
|
|
|
+ // contentsアレイ内の要素を検出
|
|
|
|
|
+ if (Array.isArray(value.contents)) {
|
|
|
|
|
+ // 完全なcontentsアレイが見つかった場合
|
|
|
|
|
+ value.contents.forEach((item) => {
|
|
|
|
|
+ if ('message' in item) {
|
|
|
|
|
+ messages.push(item.message);
|
|
|
|
|
+ }
|
|
|
|
|
+ else if ('start' in item && 'end' in item && 'text' in item) {
|
|
|
|
|
+ const validDiff = EditorAssistantDiffSchema.safeParse(item);
|
|
|
|
|
+ if (validDiff.success) {
|
|
|
|
|
+ replacements.push(validDiff.data);
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
+
|
|
|
|
|
+ // 更新をクライアントに送信
|
|
|
|
|
+ res.write(`data: ${JSON.stringify({
|
|
|
|
|
+ editorResponse: {
|
|
|
|
|
+ message: messages.length > 0 ? messages[messages.length - 1] : '',
|
|
|
|
|
+ replacements,
|
|
|
|
|
+ },
|
|
|
|
|
+ })}\n\n`);
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
+
|
|
|
|
|
+ // 元のデルタも送信
|
|
|
|
|
+ res.write(`data: ${JSON.stringify(delta)}\n\n`);
|
|
|
|
|
+ }
|
|
|
|
|
+ else {
|
|
|
|
|
+ res.write(`data: ${JSON.stringify(delta)}\n\n`);
|
|
|
|
|
+ }
|
|
|
};
|
|
};
|
|
|
|
|
|
|
|
const sendError = (message: string, code?: StreamErrorCode) => {
|
|
const sendError = (message: string, code?: StreamErrorCode) => {
|
|
@@ -125,14 +223,78 @@ export const postMessageToEditHandlersFactory: PostMessageHandlersFactory = (cro
|
|
|
sendError(errorMessage, getStreamErrorCode(errorMessage));
|
|
sendError(errorMessage, getStreamErrorCode(errorMessage));
|
|
|
}
|
|
}
|
|
|
});
|
|
});
|
|
|
|
|
+
|
|
|
stream.on('messageDelta', messageDeltaHandler);
|
|
stream.on('messageDelta', messageDeltaHandler);
|
|
|
|
|
+
|
|
|
stream.once('messageDone', () => {
|
|
stream.once('messageDone', () => {
|
|
|
|
|
+ // 処理完了時、最終的なレスポンスを送信
|
|
|
|
|
+ // 最後の確認としてJSONを解析してみる
|
|
|
|
|
+ try {
|
|
|
|
|
+ // バッファから完全なJSONを解析
|
|
|
|
|
+ const parsedJson = JSON.parse(rawBuffer);
|
|
|
|
|
+
|
|
|
|
|
+ if (parsedJson?.contents && Array.isArray(parsedJson.contents)) {
|
|
|
|
|
+ // 最終的なメッセージと差分を収集
|
|
|
|
|
+ const finalMessages: string[] = [];
|
|
|
|
|
+ const finalReplacements: EditorAssistantDiff[] = [];
|
|
|
|
|
+
|
|
|
|
|
+ for (const item of parsedJson.contents) {
|
|
|
|
|
+ if ('message' in item) {
|
|
|
|
|
+ finalMessages.push(item.message);
|
|
|
|
|
+ }
|
|
|
|
|
+ else if ('start' in item && 'end' in item && 'text' in item) {
|
|
|
|
|
+ const validDiff = EditorAssistantDiffSchema.safeParse(item);
|
|
|
|
|
+ if (validDiff.success) {
|
|
|
|
|
+ finalReplacements.push(validDiff.data);
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // 最終レスポンスをクライアントに送信(これまでに部分的に送信したものがあっても、完全なデータを再送)
|
|
|
|
|
+ res.write(`data: ${JSON.stringify({
|
|
|
|
|
+ editorResponse: {
|
|
|
|
|
+ message: finalMessages.length > 0 ? finalMessages[finalMessages.length - 1] : '',
|
|
|
|
|
+ replacements: finalReplacements,
|
|
|
|
|
+ },
|
|
|
|
|
+ isDone: true,
|
|
|
|
|
+ })}\n\n`);
|
|
|
|
|
+ }
|
|
|
|
|
+ else {
|
|
|
|
|
+ // 既に部分的に送信したデータがある場合はそれを最終データとして扱う
|
|
|
|
|
+ res.write(`data: ${JSON.stringify({
|
|
|
|
|
+ editorResponse: {
|
|
|
|
|
+ message: messages.length > 0 ? messages[messages.length - 1] : '',
|
|
|
|
|
+ replacements,
|
|
|
|
|
+ },
|
|
|
|
|
+ isDone: true,
|
|
|
|
|
+ })}\n\n`);
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ catch (e) {
|
|
|
|
|
+ logger.error('Failed to parse final JSON response', e);
|
|
|
|
|
+ // パース失敗時で既存のデータがある場合はそれを送信
|
|
|
|
|
+ if (messages.length > 0 || replacements.length > 0) {
|
|
|
|
|
+ res.write(`data: ${JSON.stringify({
|
|
|
|
|
+ editorResponse: {
|
|
|
|
|
+ message: messages.length > 0 ? messages[messages.length - 1] : '',
|
|
|
|
|
+ replacements,
|
|
|
|
|
+ },
|
|
|
|
|
+ isDone: true,
|
|
|
|
|
+ })}\n\n`);
|
|
|
|
|
+ }
|
|
|
|
|
+ else {
|
|
|
|
|
+ sendError('Failed to parse assistant response as JSON', 'INVALID_RESPONSE_FORMAT');
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
stream.off('messageDelta', messageDeltaHandler);
|
|
stream.off('messageDelta', messageDeltaHandler);
|
|
|
res.end();
|
|
res.end();
|
|
|
});
|
|
});
|
|
|
|
|
+
|
|
|
stream.once('error', (err) => {
|
|
stream.once('error', (err) => {
|
|
|
logger.error(err);
|
|
logger.error(err);
|
|
|
stream.off('messageDelta', messageDeltaHandler);
|
|
stream.off('messageDelta', messageDeltaHandler);
|
|
|
|
|
+ sendError('An error occurred while processing your request');
|
|
|
res.end();
|
|
res.end();
|
|
|
});
|
|
});
|
|
|
},
|
|
},
|