message.ts 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. import type { IUserHasId } from '@growi/core/dist/interfaces';
  2. import { ErrorV3 } from '@growi/core/dist/models';
  3. import type { Request, RequestHandler, Response } from 'express';
  4. import type { ValidationChain } from 'express-validator';
  5. import { body } from 'express-validator';
  6. import type { AssistantStream } from 'openai/lib/AssistantStream';
  7. import type { MessageDelta } from 'openai/resources/beta/threads/messages.mjs';
  8. import { getOrCreateChatAssistant } from '~/features/openai/server/services/assistant';
  9. import type Crowi from '~/server/crowi';
  10. import { accessTokenParser } from '~/server/middlewares/access-token-parser';
  11. import { apiV3FormValidator } from '~/server/middlewares/apiv3-form-validator';
  12. import type { ApiV3Response } from '~/server/routes/apiv3/interfaces/apiv3-response';
  13. import loggerFactory from '~/utils/logger';
  14. import { MessageErrorCode, type StreamErrorCode } from '../../interfaces/message-error';
  15. import { openaiClient } from '../services/client';
  16. import { getStreamErrorCode } from '../services/getStreamErrorCode';
  17. import { replaceAnnotationWithPageLink } from '../services/replace-annotation-with-page-link';
  18. import { certifyAiService } from './middlewares/certify-ai-service';
  const logger = loggerFactory('growi:routes:apiv3:openai:message');

  /** JSON body accepted by the post-message endpoint. */
  interface ReqBody {
    /** The text the user wants to send to the assistant. */
    userMessage: string;
    /** Id of the OpenAI thread to run the assistant on. */
    threadId?: string;
    /** When truthy, the assistant is primed to answer concisely. */
    summaryMode?: boolean;
  }

  /** Express request whose body is {@link ReqBody} and whose `user` is non-optional. */
  type Req = Request<undefined, Response, ReqBody> & { user: IUserHasId };

  /** Builds the ordered handler chain for the route from the application instance. */
  type PostMessageHandlersFactory = (crowi: Crowi) => RequestHandler[];
  /**
   * Builds the middleware chain for the "post message" endpoint of the OpenAI chat feature.
   *
   * The chain parses an access token, requires a logged-in user, checks the AI service,
   * validates the body, and then starts an assistant run on the given OpenAI thread.
   * The run's output is relayed to the client as a `text/event-stream` response:
   * - each message delta is written as `data: <MessageDelta JSON>\n\n`
   *   (text annotations are first rewritten via `replaceAnnotationWithPageLink`);
   * - a failed run is reported as `error: {"code","message"}\n\n`.
   *
   * Only the first message produced by the run is relayed; the response ends when that
   * message is done, or when the stream itself ends or errors.
   *
   * @param crowi - application instance, used to build the login-required middleware
   * @returns the ordered list of request handlers to mount on the route
   */
  export const postMessageHandlersFactory: PostMessageHandlersFactory = (crowi) => {
    const loginRequiredStrictly = require('~/server/middlewares/login-required')(crowi);

    const validator: ValidationChain[] = [
      body('userMessage')
        .isString()
        .withMessage('userMessage must be string')
        .notEmpty()
        .withMessage('userMessage must be set'),
      body('threadId').optional().isString().withMessage('threadId must be string'),
    ];

    return [
      accessTokenParser, loginRequiredStrictly, certifyAiService, validator, apiV3FormValidator,
      async(req: Req, res: ApiV3Response) => {
        const threadId = req.body.threadId;

        // threadId passes the validator when absent, but the handler cannot run without it
        if (threadId == null) {
          return res.apiv3Err(new ErrorV3('threadId is not set', MessageErrorCode.THREAD_ID_IS_NOT_SET), 400);
        }

        let stream: AssistantStream;
        try {
          const assistant = await getOrCreateChatAssistant();
          const thread = await openaiClient.beta.threads.retrieve(threadId);

          stream = openaiClient.beta.threads.runs.stream(thread.id, {
            assistant_id: assistant.id,
            additional_messages: [
              {
                role: 'assistant',
                content: req.body.summaryMode
                  ? 'Turn on summary mode: I will try to answer concisely, aiming for 1-3 sentences.'
                  : 'I will turn off summary mode and answer.',
              },
              { role: 'user', content: req.body.userMessage },
            ],
          });
        }
        catch (err: unknown) {
          logger.error(err);
          // TODO: improve error handling by https://redmine.weseek.co.jp/issues/155004
          const message = err instanceof Error ? err.message : String(err);
          return res.status(500).send(message);
        }

        res.writeHead(200, {
          'Content-Type': 'text/event-stream;charset=utf-8',
          'Cache-Control': 'no-cache, no-transform',
        });

        // Set once the response is finished (by us) or the connection is closed (by the client).
        // Any later write is dropped so that late stream events cannot "write after end".
        let isResponseClosed = false;

        // Event listeners are not awaited by the emitter, so an async delta handler could let
        // chunks overtake each other, or run after res.end(). Every write (and the final end)
        // is therefore chained onto this promise so output order matches event order.
        let pending: Promise<void> = Promise.resolve();
        const enqueue = (task: () => void | Promise<void>): void => {
          pending = pending
            .then(task)
            .catch((err: unknown) => {
              logger.error(err);
            });
        };

        const write = (chunk: string): void => {
          if (!isResponseClosed) {
            res.write(chunk);
          }
        };

        const endResponse = (): void => {
          if (isResponseClosed) {
            return;
          }
          isResponseClosed = true;
          res.end();
        };

        const messageDeltaHandler = (delta: MessageDelta): void => {
          enqueue(async() => {
            const content = delta.content?.[0];
            // If annotation is found, rewrite it in place before relaying the delta
            if (content?.type === 'text' && content?.text?.annotations != null) {
              try {
                await replaceAnnotationWithPageLink(content, req.user.lang);
              }
              catch (err: unknown) {
                // relay the delta as-is rather than dropping its text
                logger.error(err);
              }
            }
            write(`data: ${JSON.stringify(delta)}\n\n`);
          });
        };

        const sendError = (message: string, code?: StreamErrorCode): void => {
          enqueue(() => write(`error: ${JSON.stringify({ code, message })}\n\n`));
        };

        stream.on('event', (event) => {
          if (event.event === 'thread.run.failed') {
            const errorMessage = event.data.last_error?.message;
            if (errorMessage == null) {
              return;
            }
            logger.error(errorMessage);
            sendError(errorMessage, getStreamErrorCode(errorMessage));
          }
        });

        stream.on('messageDelta', messageDeltaHandler);

        // Only the first completed message is relayed; already-queued deltas are flushed before ending.
        stream.once('messageDone', () => {
          stream.off('messageDelta', messageDeltaHandler);
          enqueue(endResponse);
        });

        stream.on('error', (err) => {
          logger.error(err);
          stream.off('messageDelta', messageDeltaHandler);
          enqueue(endResponse);
        });

        // Make sure the response is closed even when no 'messageDone' arrives (e.g. the run failed
        // before producing a message). NOTE(review): relies on the SDK emitting 'end' when the
        // stream stops — confirm for the installed openai version.
        stream.on('end', () => {
          stream.off('messageDelta', messageDeltaHandler);
          enqueue(endResponse);
        });

        // 'close' also fires after our own res.end(); only a premature close (client went away)
        // aborts the run so it does not keep streaming in the background.
        res.on('close', () => {
          if (isResponseClosed) {
            return;
          }
          isResponseClosed = true;
          stream.off('messageDelta', messageDeltaHandler);
          stream.abort();
        });

        // Handle the SDK's abort notification here; without a listener the abort error may surface
        // as an unhandled rejection. NOTE(review): confirm against the installed openai version.
        stream.on('abort', (err) => {
          logger.debug(err);
        });
      },
    ];
  };