Yuki Takei 1 rok temu
rodzic
commit
3be91db230

+ 0 - 32
apps/app/src/features/openai/chat/stores/thread.ts

@@ -1,32 +0,0 @@
-import useSWRSubscription from 'swr/subscription';
-
-// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types
-export const useSWRSUBxOpenaiRun = (threadId: string, runId: string) => {
-  return useSWRSubscription(
-    ['openaiRun', threadId, runId],
-    ([, threadId, runId], { next }) => {
-
-      // SSEを実装したサーバに接続する
-      const eventSource = new EventSource(`/_api/v3/openai/thread/${threadId}/run/${runId}/subscribe`);
-
-      // TODO: Error handling
-      // eventSource.onerror = () => {
-      // };
-
-      eventSource.onmessage = (event) => {
-        console.log({ event });
-
-        // const parsedData = JSON.parse(event.data);
-
-        // if (parsedData.event === 'error') {
-        //   next(parsedData.error);
-        //   return;
-        // }
-
-        // next(null, parsedData.data);
-      };
-
-      return () => eventSource.close();
-    },
-  );
-};

+ 0 - 77
apps/app/src/server/routes/apiv3/openai/chat.ts

@@ -1,77 +0,0 @@
-import type { Request, RequestHandler } from 'express';
-import type { ValidationChain } from 'express-validator';
-import { body } from 'express-validator';
-
-import type Crowi from '~/server/crowi';
-import { apiV3FormValidator } from '~/server/middlewares/apiv3-form-validator';
-import { certifyAiService } from '~/server/middlewares/certify-ai-service';
-import { configManager } from '~/server/service/config-manager';
-import { openaiClient } from '~/server/service/openai';
-import { getOrCreateChatAssistant } from '~/server/service/openai/assistant';
-import loggerFactory from '~/utils/logger';
-
-
-import type { ApiV3Response } from '../interfaces/apiv3-response';
-
-const logger = loggerFactory('growi:routes:apiv3:openai:chat');
-
-type ReqBody = {
-  userMessage: string,
-  threadId?: string,
-}
-
-type Req = Request<undefined, ApiV3Response, ReqBody>
-
-type ChatHandlersFactory = (crowi: Crowi) => RequestHandler[];
-
-export const chatHandlersFactory: ChatHandlersFactory = (crowi) => {
-  const accessTokenParser = require('../../../middlewares/access-token-parser')(crowi);
-  const loginRequiredStrictly = require('../../../middlewares/login-required')(crowi);
-
-  const validator: ValidationChain[] = [
-    body('userMessage').isString().withMessage('userMessage must be string'),
-    body('threadId').optional().isString().withMessage('threadId must be string'),
-  ];
-
-  return [
-    accessTokenParser, loginRequiredStrictly, certifyAiService, validator, apiV3FormValidator,
-    async(req: Req, res: ApiV3Response) => {
-      const vectorStoreId = configManager.getConfig('crowi', 'app:openaiVectorStoreId');
-      if (vectorStoreId == null) {
-        return res.apiv3Err('OPENAI_VECTOR_STORE_ID is not setup', 503);
-      }
-
-      try {
-        const assistant = await getOrCreateChatAssistant();
-
-        const threadId = req.body.threadId;
-        const thread = threadId == null
-          ? await openaiClient.beta.threads.create({
-            messages: [{ role: 'assistant', content: req.body.userMessage }],
-            tool_resources: {
-              file_search: {
-                vector_store_ids: [vectorStoreId],
-              },
-            },
-          })
-          : await openaiClient.beta.threads.retrieve(threadId);
-
-        const run = await openaiClient.beta.threads.runs.createAndPoll(thread.id, { assistant_id: assistant.id });
-
-        if (run.status === 'completed') {
-          const messages = await openaiClient.beta.threads.messages.list(run.thread_id, {
-            limit: 1,
-            order: 'desc',
-          });
-          return res.apiv3({ messages });
-        }
-
-        return res.apiv3({});
-      }
-      catch (err) {
-        logger.error(err);
-        return res.apiv3Err(err);
-      }
-    },
-  ];
-};

+ 1 - 4
apps/app/src/server/routes/apiv3/openai/index.ts

@@ -1,17 +1,14 @@
 import express from 'express';
 
-import { chatHandlersFactory } from './chat';
 import { postMessageHandlersFactory } from './message';
-import { createThreadHandlersFactory } from './thread';
 import { rebuildVectorStoreHandlersFactory } from './rebuild-vector-store';
+import { createThreadHandlersFactory } from './thread';
 
 const router = express.Router();
 
 module.exports = (crowi) => {
   router.post('/rebuild-vector-store', rebuildVectorStoreHandlersFactory(crowi));
 
-  // deprecated
-  router.post('/chat', chatHandlersFactory(crowi));
   // create thread
   router.post('/thread', createThreadHandlersFactory(crowi));
   // post message and return streaming with SSE

+ 19 - 12
apps/app/src/server/routes/apiv3/openai/message.ts

@@ -4,6 +4,7 @@ import type { Request, RequestHandler } from 'express';
 import type { ValidationChain } from 'express-validator';
 import { body } from 'express-validator';
 import type { AssistantStream } from 'openai/lib/AssistantStream';
+import type { MessageDelta } from 'openai/resources/beta/threads/messages.mjs';
 
 import type Crowi from '~/server/crowi';
 import { openaiClient } from '~/server/service/openai';
@@ -16,6 +17,7 @@ import type { ApiV3Response } from '../interfaces/apiv3-response';
 
 const logger = loggerFactory('growi:routes:apiv3:openai:chat');
 
+
 type ReqBody = {
   userMessage: string,
   threadId?: string,
@@ -60,20 +62,25 @@ export const postMessageHandlersFactory: PostMessageHandlersFactory = (crowi) =>
         return res.status(500).send(err);
       }
 
-      res.setHeader('Content-Type', 'text/event-stream;charset=utf-8');
-      res.setHeader('Cache-Control', 'no-cache, no-transform');
-      res.setHeader('X-Accel-Buffering', 'no');
+      res.writeHead(200, {
+        'Content-Type': 'text/event-stream;charset=utf-8',
+        'Cache-Control': 'no-cache, no-transform',
+      });
 
-      try {
-        res.send(stream.toReadableStream());
-      }
-      catch (e) {
-        return res.status(500).send({ message: 'Internal server error', error: e });
-      }
-      finally {
-        res.end();
-      }
+      const messageDeltaHandler = (delta: MessageDelta) => {
+        res.write(`data: ${JSON.stringify(delta)}\n\n`);
+      };
 
+      stream.on('messageDelta', messageDeltaHandler);
+      stream.once('messageDone', () => {
+        stream.off('messageDelta', messageDeltaHandler);
+        res.end();
+      });
+      stream.once('error', (err) => {
+        logger.error(err);
+        stream.off('messageDelta', messageDeltaHandler);
+        res.end();
+      });
     },
   ];
 };