Преглед изворни кода

WIP: add a route for editor assistant

Yuki Takei пре 1 година
родитељ
комит
119fd19e05

+ 140 - 0
apps/app/src/features/openai/server/routes/edit.ts

@@ -0,0 +1,140 @@
+import type { IUserHasId } from '@growi/core/dist/interfaces';
+import { ErrorV3 } from '@growi/core/dist/models';
+import type { Request, RequestHandler, Response } from 'express';
+import type { ValidationChain } from 'express-validator';
+import { body } from 'express-validator';
+import type { AssistantStream } from 'openai/lib/AssistantStream';
+import type { MessageDelta } from 'openai/resources/beta/threads/messages.mjs';
+
+import { getOrCreateEditorAssistant } from '~/features/openai/server/services/assistant';
+import type Crowi from '~/server/crowi';
+import { accessTokenParser } from '~/server/middlewares/access-token-parser';
+import { apiV3FormValidator } from '~/server/middlewares/apiv3-form-validator';
+import type { ApiV3Response } from '~/server/routes/apiv3/interfaces/apiv3-response';
+import loggerFactory from '~/utils/logger';
+
+import { MessageErrorCode, type StreamErrorCode } from '../../interfaces/message-error';
+import { openaiClient } from '../services/client';
+import { getStreamErrorCode } from '../services/getStreamErrorCode';
+import { getOpenaiService } from '../services/openai';
+import { replaceAnnotationWithPageLink } from '../services/replace-annotation-with-page-link';
+
+import { certifyAiService } from './middlewares/certify-ai-service';
+
+const logger = loggerFactory('growi:routes:apiv3:openai:message');
+
+
+type ReqBody = {
+  userMessage: string,
+  markdown: string,
+  aiAssistantId?: string,
+  threadId?: string,
+}
+
+type Req = Request<undefined, Response, ReqBody> & {
+  user: IUserHasId,
+}
+
+type PostMessageHandlersFactory = (crowi: Crowi) => RequestHandler[];
+
+export const postMessageToEditHandlersFactory: PostMessageHandlersFactory = (crowi) => {
+  const loginRequiredStrictly = require('~/server/middlewares/login-required')(crowi);
+
+  const validator: ValidationChain[] = [
+    body('userMessage')
+      .isString()
+      .withMessage('userMessage must be string')
+      .notEmpty()
+      .withMessage('userMessage must be set'),
+    body('markdown')
+      .isString()
+      .withMessage('userMessage must be string')
+      .notEmpty()
+      .withMessage('userMessage must be set'),
+    body('aiAssistantId').optional().isMongoId().withMessage('aiAssistantId must be string'),
+    body('threadId').optional().isString().withMessage('threadId must be string'),
+  ];
+
+  return [
+    accessTokenParser, loginRequiredStrictly, certifyAiService, validator, apiV3FormValidator,
+    async(req: Req, res: ApiV3Response) => {
+      const { aiAssistantId, threadId } = req.body;
+
+      if (threadId == null) {
+        return res.apiv3Err(new ErrorV3('threadId is not set', MessageErrorCode.THREAD_ID_IS_NOT_SET), 400);
+      }
+
+      const openaiService = getOpenaiService();
+      if (openaiService == null) {
+        return res.apiv3Err(new ErrorV3('GROWI AI is not enabled'), 501);
+      }
+
+      let stream: AssistantStream;
+
+      try {
+        const assistant = await getOrCreateEditorAssistant();
+
+        const thread = await openaiClient.beta.threads.retrieve(threadId);
+
+        stream = openaiClient.beta.threads.runs.stream(thread.id, {
+          assistant_id: assistant.id,
+          additional_messages: [
+            {
+              role: 'assistant',
+              content: '', // TODO: add a message to notify the user that the editing is started
+            },
+            { role: 'user', content: req.body.userMessage },
+          ],
+        });
+
+      }
+      catch (err) {
+        logger.error(err);
+
+        // TODO: improve error handling by https://redmine.weseek.co.jp/issues/155004
+        return res.status(500).send(err.message);
+      }
+
+      res.writeHead(200, {
+        'Content-Type': 'text/event-stream;charset=utf-8',
+        'Cache-Control': 'no-cache, no-transform',
+      });
+
+      const messageDeltaHandler = async(delta: MessageDelta) => {
+        const content = delta.content?.[0];
+
+        // If annotation is found
+        if (content?.type === 'text' && content?.text?.annotations != null) {
+          await replaceAnnotationWithPageLink(content, req.user.lang);
+        }
+
+        res.write(`data: ${JSON.stringify(delta)}\n\n`);
+      };
+
+      const sendError = (message: string, code?: StreamErrorCode) => {
+        res.write(`error: ${JSON.stringify({ code, message })}\n\n`);
+      };
+
+      stream.on('event', (delta) => {
+        if (delta.event === 'thread.run.failed') {
+          const errorMessage = delta.data.last_error?.message;
+          if (errorMessage == null) {
+            return;
+          }
+          logger.error(errorMessage);
+          sendError(errorMessage, getStreamErrorCode(errorMessage));
+        }
+      });
+      stream.on('messageDelta', messageDeltaHandler);
+      stream.once('messageDone', () => {
+        stream.off('messageDelta', messageDeltaHandler);
+        res.end();
+      });
+      stream.once('error', (err) => {
+        logger.error(err);
+        stream.off('messageDelta', messageDeltaHandler);
+        res.end();
+      });
+    },
+  ];
+};

+ 4 - 0
apps/app/src/features/openai/server/routes/index.ts

@@ -39,6 +39,10 @@ export const factory = (crowi: Crowi): express.Router => {
       router.get('/messages/:aiAssistantId/:threadId', getMessagesFactory(crowi));
     });
 
+    import('./edit').then(({ postMessageToEditHandlersFactory }) => {
+      router.post('/edit', postMessageToEditHandlersFactory(crowi));
+    });
+
     import('./ai-assistant').then(({ createAiAssistantFactory }) => {
       router.post('/ai-assistant', createAiAssistantFactory(crowi));
     });

+ 25 - 12
apps/app/src/features/openai/server/services/assistant/assistant.ts

@@ -8,27 +8,30 @@ import { openaiClient } from '../client';
 const AssistantType = {
   SEARCH: 'Search',
   CHAT: 'Chat',
+  EDIT: 'Edit',
 } as const;
 
 const AssistantDefaultModelMap: Record<AssistantType, OpenAI.Chat.ChatModel> = {
   [AssistantType.SEARCH]: 'gpt-4o-mini',
   [AssistantType.CHAT]: 'gpt-4o-mini',
-};
-
-const isValidChatModel = (model: string): model is OpenAI.Chat.ChatModel => {
-  return model.startsWith('gpt-');
+  [AssistantType.EDIT]: 'gpt-4o-mini',
 };
 
 const getAssistantModelByType = (type: AssistantType): OpenAI.Chat.ChatModel => {
-  const configValue = type === AssistantType.SEARCH
-    ? undefined // TODO: add the value for 'openai:assistantModel:search' to config-definition.ts
-    : configManager.getConfig('openai:assistantModel:chat');
-
-  if (typeof configValue === 'string' && isValidChatModel(configValue)) {
-    return configValue;
-  }
+  const configValue = (() => {
+    switch (type) {
+      case AssistantType.SEARCH:
+        // return configManager.getConfig('openai:assistantModel:search');
+        return undefined;
+      case AssistantType.CHAT:
+        return configManager.getConfig('openai:assistantModel:chat');
+      case AssistantType.EDIT:
+        // return configManager.getConfig('openai:assistantModel:edit');
+        return undefined;
+    }
+  })();
 
-  return AssistantDefaultModelMap[type];
+  return configValue ?? AssistantDefaultModelMap[type];
 };
 
 type AssistantType = typeof AssistantType[keyof typeof AssistantType];
@@ -103,3 +106,13 @@ export const getOrCreateChatAssistant = async(): Promise<OpenAI.Beta.Assistant>
   chatAssistant = await getOrCreateAssistant(AssistantType.CHAT);
   return chatAssistant;
 };
+
+let editorAssistant: OpenAI.Beta.Assistant | undefined;
+export const getOrCreateEditorAssistant = async(): Promise<OpenAI.Beta.Assistant> => {
+  if (editorAssistant != null) {
+    return editorAssistant;
+  }
+
+  editorAssistant = await getOrCreateAssistant(AssistantType.EDIT);
+  return editorAssistant;
+};