post-message.ts 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. import type { IUserHasId } from '@growi/core/dist/interfaces';
  2. import { ErrorV3 } from '@growi/core/dist/models';
  3. import type { Request, RequestHandler, Response } from 'express';
  4. import type { ValidationChain } from 'express-validator';
  5. import { body } from 'express-validator';
  6. import type { AssistantStream } from 'openai/lib/AssistantStream';
  7. import type { MessageDelta } from 'openai/resources/beta/threads/messages.mjs';
  8. import { getOrCreateChatAssistant } from '~/features/openai/server/services/assistant';
  9. import type Crowi from '~/server/crowi';
  10. import { accessTokenParser } from '~/server/middlewares/access-token-parser';
  11. import { apiV3FormValidator } from '~/server/middlewares/apiv3-form-validator';
  12. import type { ApiV3Response } from '~/server/routes/apiv3/interfaces/apiv3-response';
  13. import loggerFactory from '~/utils/logger';
  14. import { MessageErrorCode, type StreamErrorCode } from '../../../interfaces/message-error';
  15. import AiAssistantModel from '../../models/ai-assistant';
  16. import ThreadRelationModel from '../../models/thread-relation';
  17. import { openaiClient } from '../../services/client';
  18. import { getStreamErrorCode } from '../../services/getStreamErrorCode';
  19. import { getOpenaiService } from '../../services/openai';
  20. import { replaceAnnotationWithPageLink } from '../../services/replace-annotation-with-page-link';
  21. import { certifyAiService } from '../middlewares/certify-ai-service';
  22. const logger = loggerFactory('growi:routes:apiv3:openai:message');
  23. type ReqBody = {
  24. userMessage: string,
  25. aiAssistantId: string,
  26. threadId?: string,
  27. summaryMode?: boolean,
  28. extendedThinkingMode?: boolean,
  29. }
  30. type Req = Request<undefined, Response, ReqBody> & {
  31. user: IUserHasId,
  32. }
  33. type PostMessageHandlersFactory = (crowi: Crowi) => RequestHandler[];
  34. export const postMessageHandlersFactory: PostMessageHandlersFactory = (crowi) => {
  35. const loginRequiredStrictly = require('~/server/middlewares/login-required')(crowi);
  36. const validator: ValidationChain[] = [
  37. body('userMessage')
  38. .isString()
  39. .withMessage('userMessage must be string')
  40. .notEmpty()
  41. .withMessage('userMessage must be set'),
  42. body('aiAssistantId').isMongoId().withMessage('aiAssistantId must be string'),
  43. body('threadId').optional().isString().withMessage('threadId must be string'),
  44. ];
  45. return [
  46. accessTokenParser, loginRequiredStrictly, certifyAiService, validator, apiV3FormValidator,
  47. async(req: Req, res: ApiV3Response) => {
  48. const { aiAssistantId, threadId } = req.body;
  49. if (threadId == null) {
  50. return res.apiv3Err(new ErrorV3('threadId is not set', MessageErrorCode.THREAD_ID_IS_NOT_SET), 400);
  51. }
  52. const openaiService = getOpenaiService();
  53. if (openaiService == null) {
  54. return res.apiv3Err(new ErrorV3('GROWI AI is not enabled'), 501);
  55. }
  56. const isAiAssistantUsable = await openaiService.isAiAssistantUsable(aiAssistantId, req.user);
  57. if (!isAiAssistantUsable) {
  58. return res.apiv3Err(new ErrorV3('The specified AI assistant is not usable'), 400);
  59. }
  60. const aiAssistant = await AiAssistantModel.findById(aiAssistantId);
  61. if (aiAssistant == null) {
  62. return res.apiv3Err(new ErrorV3('AI assistant not found'), 404);
  63. }
  64. const threadRelation = await ThreadRelationModel.findOne({ threadId });
  65. if (threadRelation == null) {
  66. return res.apiv3Err(new ErrorV3('ThreadRelation not found'), 404);
  67. }
  68. threadRelation.updateThreadExpiration();
  69. let stream: AssistantStream;
  70. const useSummaryMode = req.body.summaryMode ?? false;
  71. const useExtendedThinkingMode = req.body.extendedThinkingMode ?? false;
  72. try {
  73. const assistant = await getOrCreateChatAssistant();
  74. const thread = await openaiClient.beta.threads.retrieve(threadId);
  75. stream = openaiClient.beta.threads.runs.stream(thread.id, {
  76. assistant_id: assistant.id,
  77. additional_messages: [
  78. { role: 'user', content: req.body.userMessage },
  79. ],
  80. additional_instructions: [
  81. aiAssistant.additionalInstruction,
  82. useSummaryMode
  83. ? '**IMPORTANT** : Turn on "Summary Mode"'
  84. : '**IMPORTANT** : Turn off "Summary Mode"',
  85. useExtendedThinkingMode
  86. ? '**IMPORTANT** : Turn on "Extended Thinking Mode"'
  87. : '**IMPORTANT** : Turn off "Extended Thinking Mode"',
  88. ].join('\n'),
  89. });
  90. }
  91. catch (err) {
  92. logger.error(err);
  93. // TODO: improve error handling by https://redmine.weseek.co.jp/issues/155004
  94. return res.status(500).send(err.message);
  95. }
  96. res.writeHead(200, {
  97. 'Content-Type': 'text/event-stream;charset=utf-8',
  98. 'Cache-Control': 'no-cache, no-transform',
  99. });
  100. const messageDeltaHandler = async(delta: MessageDelta) => {
  101. const content = delta.content?.[0];
  102. // If annotation is found
  103. if (content?.type === 'text' && content?.text?.annotations != null) {
  104. await replaceAnnotationWithPageLink(content, req.user.lang);
  105. }
  106. res.write(`data: ${JSON.stringify(delta)}\n\n`);
  107. };
  108. const sendError = (message: string, code?: StreamErrorCode) => {
  109. res.write(`error: ${JSON.stringify({ code, message })}\n\n`);
  110. };
  111. stream.on('event', (delta) => {
  112. if (delta.event === 'thread.run.failed') {
  113. const errorMessage = delta.data.last_error?.message;
  114. if (errorMessage == null) {
  115. return;
  116. }
  117. logger.error(errorMessage);
  118. sendError(errorMessage, getStreamErrorCode(errorMessage));
  119. }
  120. });
  121. stream.on('messageDelta', messageDeltaHandler);
  122. stream.once('messageDone', () => {
  123. stream.off('messageDelta', messageDeltaHandler);
  124. res.end();
  125. });
  126. stream.once('error', (err) => {
  127. logger.error(err);
  128. stream.off('messageDelta', messageDeltaHandler);
  129. res.end();
  130. });
  131. },
  132. ];
  133. };