Yuki Takei 1 год назад
Родитель
Сommit
eafc2f4ad8

+ 8 - 0
apps/app/src/server/routes/apiv3/openai/index.ts

@@ -1,10 +1,18 @@
 import express from 'express';
 
 import { chatHandlersFactory } from './chat';
+import { postMessageHandlersFactory } from './message';
+import { createThreadHandlersFactory } from './thread';
 
 const router = express.Router();
 
 module.exports = (crowi) => {
+  // deprecated
   router.post('/chat', chatHandlersFactory(crowi));
+
+  // create thread
+  router.post('/thread', createThreadHandlersFactory(crowi));
+  // post message and return streaming with SSE
+  router.post('/message', postMessageHandlersFactory(crowi));
   return router;
 };

+ 77 - 0
apps/app/src/server/routes/apiv3/openai/message.ts

@@ -0,0 +1,77 @@
+import type { Request, RequestHandler } from 'express';
+import type { ValidationChain } from 'express-validator';
+import { body } from 'express-validator';
+import type { AssistantStream } from 'openai/lib/AssistantStream';
+
+import type Crowi from '~/server/crowi';
+import { openaiClient } from '~/server/service/openai';
+import { getOrCreateChatAssistant } from '~/server/service/openai/assistant';
+import loggerFactory from '~/utils/logger';
+
+import { apiV3FormValidator } from '../../../middlewares/apiv3-form-validator';
+import type { ApiV3Response } from '../interfaces/apiv3-response';
+
+
+const logger = loggerFactory('growi:routes:apiv3:openai:chat');
+
+type ReqBody = {
+  userMessage: string,
+  threadId?: string,
+}
+
+type Req = Request<undefined, ApiV3Response, ReqBody>
+
+type PostMessageHandlersFactory = (crowi: Crowi) => RequestHandler[];
+
+export const postMessageHandlersFactory: PostMessageHandlersFactory = (crowi) => {
+  const accessTokenParser = require('../../../middlewares/access-token-parser')(crowi);
+  const loginRequiredStrictly = require('../../../middlewares/login-required')(crowi);
+
+  const validator: ValidationChain[] = [
+    body('userMessage').isString().withMessage('userMessage must be string'),
+    body('threadId').isString().withMessage('threadId must be string'),
+  ];
+
+  return [
+    accessTokenParser, loginRequiredStrictly, validator, apiV3FormValidator,
+    async(req: Req, res: ApiV3Response) => {
+
+      const threadId = req.body.threadId;
+
+      assert(threadId != null);
+
+      let stream: AssistantStream;
+
+      try {
+        const assistant = await getOrCreateChatAssistant();
+
+        const thread = await openaiClient.beta.threads.retrieve(threadId);
+
+        stream = openaiClient.beta.threads.runs.stream(thread.id, {
+          assistant_id: assistant.id,
+          additional_messages: [{ role: 'assistant', content: req.body.userMessage }],
+        });
+
+      }
+      catch (err) {
+        logger.error(err);
+        return res.status(500).send(err);
+      }
+
+      res.setHeader('Content-Type', 'text/event-stream;charset=utf-8');
+      res.setHeader('Cache-Control', 'no-cache, no-transform');
+      res.setHeader('X-Accel-Buffering', 'no');
+
+      try {
+        for await (const data of stream) {
+          res.write(data);
+        }
+      }
+      catch (e) {
+        return res.status(500).send({ message: 'Internal server error', error: e });
+      }
+
+      res.end();
+    },
+  ];
+};

+ 60 - 0
apps/app/src/server/routes/apiv3/openai/thread.ts

@@ -0,0 +1,60 @@
+import type { Request, RequestHandler } from 'express';
+import type { ValidationChain } from 'express-validator';
+import { body } from 'express-validator';
+
+import type Crowi from '~/server/crowi';
+import { openaiClient } from '~/server/service/openai';
+import loggerFactory from '~/utils/logger';
+
+import { apiV3FormValidator } from '../../../middlewares/apiv3-form-validator';
+import type { ApiV3Response } from '../interfaces/apiv3-response';
+
+const logger = loggerFactory('growi:routes:apiv3:openai:chat');
+
+type ReqBody = {
+  userMessage: string,
+  threadId?: string,
+}
+
+type Req = Request<undefined, ApiV3Response, ReqBody>
+
+type CreateThreadFactory = (crowi: Crowi) => RequestHandler[];
+
+export const createThreadHandlersFactory: CreateThreadFactory = (crowi) => {
+  const accessTokenParser = require('../../../middlewares/access-token-parser')(crowi);
+  const loginRequiredStrictly = require('../../../middlewares/login-required')(crowi);
+
+  const validator: ValidationChain[] = [
+    body('threadId').optional().isString().withMessage('threadId must be string'),
+  ];
+
+  return [
+    accessTokenParser, loginRequiredStrictly, validator, apiV3FormValidator,
+    async(req: Req, res: ApiV3Response) => {
+
+      const vectorStoreId = process.env.OPENAI_VECTOR_STORE_ID;
+      if (vectorStoreId == null) {
+        return res.apiv3Err('OPENAI_VECTOR_STORE_ID is not setup', 503);
+      }
+
+      try {
+        const threadId = req.body.threadId;
+        const thread = threadId == null
+          ? await openaiClient.beta.threads.create({
+            tool_resources: {
+              file_search: {
+                vector_store_ids: [vectorStoreId],
+              },
+            },
+          })
+          : await openaiClient.beta.threads.retrieve(threadId);
+
+        return res.apiv3({ thread });
+      }
+      catch (err) {
+        logger.error(err);
+        return res.apiv3Err(err);
+      }
+    },
+  ];
+};