index.ts 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. import type { IUserHasId } from '@growi/core/dist/interfaces';
  2. import { ErrorV3 } from '@growi/core/dist/models';
  3. import type { Request, RequestHandler, Response } from 'express';
  4. import type { ValidationChain } from 'express-validator';
  5. import { body } from 'express-validator';
  6. import { zodResponseFormat } from 'openai/helpers/zod';
  7. import type { MessageDelta } from 'openai/resources/beta/threads/messages.mjs';
  8. import { z } from 'zod';
  9. // 必要なインポート
  10. import { getOrCreateEditorAssistant } from '~/features/openai/server/services/assistant';
  11. import type Crowi from '~/server/crowi';
  12. import { accessTokenParser } from '~/server/middlewares/access-token-parser';
  13. import { apiV3FormValidator } from '~/server/middlewares/apiv3-form-validator';
  14. import type { ApiV3Response } from '~/server/routes/apiv3/interfaces/apiv3-response';
  15. import loggerFactory from '~/utils/logger';
  16. import { MessageErrorCode } from '../../../interfaces/message-error';
  17. import { openaiClient } from '../../services/client';
  18. import { getStreamErrorCode } from '../../services/getStreamErrorCode';
  19. import { getOpenaiService } from '../../services/openai';
  20. import { replaceAnnotationWithPageLink } from '../../services/replace-annotation-with-page-link';
  21. import { certifyAiService } from '../middlewares/certify-ai-service';
  22. import { SseHelper } from '../utils/sse-helper';
  23. import { EditorStreamProcessor } from './editor-stream-processor';
  24. import { EditorAssistantDiffSchema, EditorAssistantMessageSchema } from './schema';
  const logger = loggerFactory('growi:routes:apiv3:openai:message');

  // -----------------------------------------------------------------------------
  // Type definitions
  // -----------------------------------------------------------------------------

  /**
   * Structured-output schema for the editor assistant's reply: an ordered list whose
   * items each match either `EditorAssistantMessageSchema` or `EditorAssistantDiffSchema`.
   */
  const EditorAssistantResponseSchema = z
    .object({
      contents: z.array(
        z.union([EditorAssistantMessageSchema, EditorAssistantDiffSchema]),
      ),
    })
    .describe('The response format for the editor assistant');

  /** Request body accepted by the editor assistant endpoint. */
  type ReqBody = {
    /** The user's edit request. */
    userMessage: string;
    /** The markdown content the user wants to edit. */
    markdown: string;
    /** Optional AI assistant id. */
    aiAssistantId?: string;
    /** OpenAI thread id; typed optional, but the handler rejects requests without it. */
    threadId?: string;
  };

  /** Express request for this endpoint, with the authenticated user attached. */
  type Req = Request<undefined, Response, ReqBody> & {
    user: IUserHasId;
  };
  // -----------------------------------------------------------------------------
  // Endpoint handler factory
  // -----------------------------------------------------------------------------
  type PostMessageHandlersFactory = (crowi: Crowi) => RequestHandler[];

  /**
   * Build the handler chain for the editor assistant endpoint.
   *
   * Access-token parsing, login, AI-service certification and body validation run
   * first. The final handler starts an OpenAI assistant run on the given thread and
   * relays it to the client as Server-Sent Events: every raw `MessageDelta` is
   * forwarded, the accumulated text is also fed to `EditorStreamProcessor`, and the
   * processor's final result is sent when the first message of the run completes.
   *
   * @param crowi - application instance, used to build the `login-required` middleware
   * @returns the middleware/handler array to mount on the route
   */
  export const postMessageToEditHandlersFactory: PostMessageHandlersFactory = (crowi) => {
    const loginRequiredStrictly = require('~/server/middlewares/login-required')(crowi);

    // Body validation; results are checked by apiV3FormValidator later in the chain.
    const validator: ValidationChain[] = [
      body('userMessage')
        .isString()
        .withMessage('userMessage must be string')
        .notEmpty()
        .withMessage('userMessage must be set'),
      body('markdown')
        .isString()
        .withMessage('markdown must be string')
        .notEmpty()
        .withMessage('markdown must be set'),
      body('aiAssistantId').optional().isMongoId().withMessage('aiAssistantId must be a valid MongoId'),
      body('threadId').optional().isString().withMessage('threadId must be string'),
    ];

    return [
      accessTokenParser, loginRequiredStrictly, certifyAiService, validator, apiV3FormValidator,
      async(req: Req, res: ApiV3Response) => {
        const { userMessage, markdown, threadId } = req.body;

        // threadId is optional in the validator but mandatory for this endpoint.
        if (threadId == null) {
          return res.apiv3Err(new ErrorV3('threadId is not set', MessageErrorCode.THREAD_ID_IS_NOT_SET), 400);
        }

        const openaiService = getOpenaiService();
        if (openaiService == null) {
          return res.apiv3Err(new ErrorV3('GROWI AI is not enabled'), 501);
        }

        const sseHelper = new SseHelper(res);
        const streamProcessor = new EditorStreamProcessor(sseHelper);

        // Text accumulated from all text deltas received so far.
        let rawBuffer = '';
        // Set once the SSE response has been finalized or the client has disconnected.
        // Every writer checks it so nothing is written after the response has ended.
        let isFinished = false;

        try {
          // Do every step that can fail *before* sending the 200 SSE headers, so such
          // failures can still be answered with a proper HTTP error in the catch below.
          const assistant = await getOrCreateEditorAssistant();
          const thread = await openaiClient.beta.threads.retrieve(threadId);

          // NOTE(review): the editor instructions are sent as an 'assistant'-role message;
          // `additional_instructions` may be a better fit for system-style guidance — confirm.
          const stream = openaiClient.beta.threads.runs.stream(thread.id, {
            assistant_id: assistant.id,
            additional_messages: [
              {
                role: 'assistant',
                content: `You are an Editor Assistant for GROWI, a markdown wiki system.
Your task is to help users edit their markdown content based on their requests.
RESPONSE FORMAT:
You must respond with a JSON object in the following format:
{
"contents": [
{ "message": "Your friendly message explaining what changes were made or suggested" },
{ "start": 0, "end": 10, "text": "New text 1" },
{ "message": "Additional explanation if needed" },
{ "start": 20, "end": 30, "text": "New text 2" },
...more items if needed
]
}
The array should contain:
- Objects with a "message" key for explanatory text to the user
- Objects with "start", "end", and "text" keys for replacements
If no changes are needed, include only message objects with explanations.
Always provide messages in the same language as the user's request.`,
              },
              {
                role: 'user',
                content: `Current markdown content:\n\`\`\`markdown\n${markdown}\n\`\`\`\n\nUser request: ${userMessage}`,
              },
            ],
            response_format: zodResponseFormat(EditorAssistantResponseSchema, 'editor_assistant_response'),
          });

          res.writeHead(200, {
            'Content-Type': 'text/event-stream;charset=utf-8',
            'Cache-Control': 'no-cache, no-transform',
          });

          // NOTE(review): this handler is async and may await annotation replacement, so a
          // delta carrying annotations can be appended to rawBuffer after a later delta.
          const messageDeltaHandler = async(delta: MessageDelta) => {
            if (isFinished) return;
            try {
              const content = delta.content?.[0];

              if (content?.type === 'text' && content?.text?.annotations != null) {
                await replaceAnnotationWithPageLink(content, req.user.lang);
                // The response may have been finalized while we were awaiting.
                if (isFinished) return;
              }

              if (content?.type === 'text' && content.text?.value) {
                rawBuffer += content.text.value;
                streamProcessor.process(rawBuffer);
              }

              // The original delta is always forwarded to the client as well.
              sseHelper.writeData(delta);
            }
            catch (err) {
              // Listener results are presumably not awaited by the stream, so catch here
              // rather than leave a rejected promise unhandled.
              logger.error('Failed to handle message delta:', err);
            }
          };

          /**
           * Finalize the SSE response exactly once.
           *
           * @param writeLast - writes whatever must precede the end of the stream
           *   (final result or error) while the stream processor is still alive
           */
          const finish = (writeLast?: () => void): void => {
            if (isFinished) return;
            isFinished = true;
            stream.off('messageDelta', messageDeltaHandler);
            writeLast?.();
            streamProcessor.destroy();
            sseHelper.end();
          };

          stream.on('messageDelta', messageDeltaHandler);

          // A failed run is reported in-band (when it carries a message) and closes the response.
          stream.on('event', (event) => {
            if (event.event !== 'thread.run.failed') return;
            const errorMessage = event.data.last_error?.message;
            if (errorMessage != null) {
              logger.error(errorMessage);
            }
            finish(() => {
              if (errorMessage != null) {
                sseHelper.writeError(errorMessage, getStreamErrorCode(errorMessage));
              }
            });
          });

          // Only the first completed message is used: send its final result and close.
          stream.once('messageDone', () => {
            finish(() => streamProcessor.sendFinalResult(rawBuffer));
          });

          // Keep this listener attached for the stream's lifetime (never `off` it).
          stream.once('error', (err) => {
            logger.error('Stream error:', err);
            finish(() => sseHelper.writeError('An error occurred while processing your request'));
          });

          // Catch-all: close the response if the run ends without any of the events above
          // having finalized it (e.g. the run finishes without producing a message).
          stream.once('end', () => finish());

          // `res` 'close' fires when the response completes *or* the connection drops,
          // whereas `req` 'close' only signals that the request itself was completed.
          res.on('close', () => {
            if (isFinished) return;
            isFinished = true;
            stream.off('messageDelta', messageDeltaHandler);
            streamProcessor.destroy();
            // NOTE(review): the OpenAI run keeps going after the client leaves; consider
            // `stream.abort()` together with an 'abort' listener.
            logger.debug('Connection closed by client');
          });
        }
        catch (err) {
          logger.error('Error in edit handler:', err);
          streamProcessor.destroy();

          if (res.headersSent) {
            // Streaming has already begun, so the HTTP status can no longer change:
            // report the failure in-band and close the stream (once).
            if (!isFinished) {
              isFinished = true;
              sseHelper.writeError('An error occurred while processing your request');
              sseHelper.end();
            }
            return;
          }

          const message = err instanceof Error ? err.message : 'An error occurred while processing your request';
          return res.apiv3Err(new ErrorV3(message), 500);
        }
      },
    ];
  };