Browse Source

Merge pull request #9058 from weseek/imprv/openai-chat-sse

imprv: OpenAI chat by SSE
mergify[bot] 1 year ago
parent
commit
939652cca8

+ 21 - 0
apps/app/src/client/components/RagSearch/RagSearchModal.module.scss

@@ -0,0 +1,21 @@
+@use '@growi/core-styles/scss/bootstrap/init' as bs;
+@use '@growi/ui/scss/atoms/btn-muted';
+
+.rag-search-modal :global {
+
+  .textarea-ask {
+    max-height: 30vh;
+  }
+
+  .btn-submit {
+    font-size: 1.1em;
+  }
+}
+
+
+// == Colors
+.rag-search-modal :global {
+  .btn-submit {
+    @include btn-muted.colorize(bs.$purple, bs.$purple);
+  }
+}

+ 167 - 47
apps/app/src/client/components/RagSearch/RagSearchModal.tsx

@@ -1,5 +1,7 @@
-import React, { useState } from 'react';
+import type { KeyboardEvent } from 'react';
+import React, { useCallback, useEffect, useState } from 'react';
 
+import { useForm, Controller } from 'react-hook-form';
 import { Modal, ModalBody, ModalHeader } from 'reactstrap';
 
 import { apiv3Post } from '~/client/util/apiv3-client';
@@ -7,7 +9,11 @@ import { useRagSearchModal } from '~/stores/rag-search';
 import loggerFactory from '~/utils/logger';
 
 import { MessageCard } from './MessageCard';
+import { ResizableTextarea } from './ResizableTextArea';
 
+import styles from './RagSearchModal.module.scss';
+
+const moduleClass = styles['rag-search-modal'];
 
 const logger = loggerFactory('growi:clinet:components:RagSearchModal');
 
@@ -18,76 +24,190 @@ type Message = {
   isUserMessage?: boolean,
 }
 
+type FormData = {
+  input: string;
+};
+
 const RagSearchModal = (): JSX.Element => {
 
-  const [input, setInput] = useState('');
+  const form = useForm<FormData>({
+    defaultValues: {
+      input: '',
+    },
+  });
 
   const [threadId, setThreadId] = useState<string | undefined>();
-  const [messages, setMessages] = useState<Message[]>([]);
+  const [messageLogs, setMessageLogs] = useState<Message[]>([]);
+  const [lastMessage, setLastMessage] = useState<Message>();
 
   const { data: ragSearchModalData, close: closeRagSearchModal } = useRagSearchModal();
 
-  const onClickSubmitUserMessageHandler = async() => {
-    const newUserMessage = { id: messages.length.toString(), content: input, isUserMessage: true };
-    setMessages(msgs => [...msgs, newUserMessage]);
+  const isOpened = ragSearchModalData?.isOpened ?? false;
 
-    setInput('');
+  useEffect(() => {
+    // clear states when the modal is closed
+    if (!isOpened) {
+      setMessageLogs([]);
+      setThreadId(undefined);
+    }
+  }, [isOpened]);
 
-    try {
-      const res = await apiv3Post('/openai/chat', { userMessage: input, threadId });
-      const assistantMessageData = res.data.messages;
-
-      if (assistantMessageData.data.length > 0) {
-        const newMessages: Message[] = assistantMessageData.data.reverse()
-          .map((message: any) => {
-            return {
-              id: message.id,
-              content: message.content[0].text.value,
-            };
-          });
+  useEffect(() => {
+    // do nothing when the modal is closed or threadId is already set
+    if (!isOpened || threadId != null) {
+      return;
+    }
 
-        setMessages(msgs => [...msgs, ...newMessages]);
-        setThreadId(assistantMessageData.data[0].threadId);
+    const createThread = async() => {
+      // create thread
+      try {
+        const res = await apiv3Post('/openai/thread', { threadId });
+        const thread = res.data.thread;
+
+        setThreadId(thread.id);
+      }
+      catch (err) {
+        logger.error(err.toString());
+      }
+    };
+
+    createThread();
+  }, [isOpened, threadId]);
+
+  const submit = useCallback(async(data: FormData) => {
+    const { length: logLength } = messageLogs;
+
+    // post message
+    try {
+      form.clearErrors();
+
+      const response = await fetch('/_api/v3/openai/message', {
+        method: 'POST',
+        headers: { 'Content-Type': 'application/json' },
+        body: JSON.stringify({ userMessage: data.input, threadId }),
+      });
+
+      if (!response.ok) {
+        const resJson = await response.json();
+        if ('errors' in resJson) {
+          // eslint-disable-next-line @typescript-eslint/no-unused-vars
+          const errors = resJson.errors.map(({ message }) => message).join(', ');
+          form.setError('input', { type: 'manual', message: `[${response.status}] ${errors}` });
+        }
+        return;
       }
 
+      // add user message to the logs
+      const newUserMessage = { id: logLength.toString(), content: data.input, isUserMessage: true };
+      setMessageLogs(msgs => [...msgs, newUserMessage]);
+
+      // reset form
+      form.reset();
+
+      // add assistant message
+      const newAssistantMessage = { id: (logLength + 1).toString(), content: '' };
+      setLastMessage(newAssistantMessage);
+
+      const reader = response.body?.getReader();
+      const decoder = new TextDecoder('utf-8');
+
+      const read = async() => {
+        if (reader == null) return;
+
+        const { done, value } = await reader.read();
+
+        // add assistant message to the logs
+        if (done) {
+          setLastMessage((lastMessage) => {
+            if (lastMessage == null) return;
+            setMessageLogs(msgs => [...msgs, lastMessage]);
+            return undefined;
+          });
+          return;
+        }
+
+        const chunk = decoder.decode(value);
+
+        // Extract text values from the chunk
+        const textValues = chunk
+          .split('\n\n')
+          .filter(line => line.trim().startsWith('data:'))
+          .map((line) => {
+            const data = JSON.parse(line.replace('data: ', ''));
+            return data.content[0].text.value;
+          });
+
+        // append text values to the assistant message
+        setLastMessage((prevMessage) => {
+          if (prevMessage == null) return;
+          return {
+            ...prevMessage,
+            content: prevMessage.content + textValues.join(''),
+          };
+        });
+
+        read();
+      };
+      read();
     }
     catch (err) {
       logger.error(err.toString());
+      form.setError('input', { type: 'manual', message: err.toString() });
+    }
+
+  }, [form, messageLogs, threadId]);
+
+  const keyDownHandler = (event: KeyboardEvent<HTMLTextAreaElement>) => {
+    if (event.key === 'Enter' && (event.ctrlKey || event.metaKey)) {
+      form.handleSubmit(submit)();
     }
   };
 
   return (
-    <Modal size="lg" isOpen={ragSearchModalData?.isOpened ?? false} toggle={closeRagSearchModal} data-testid="search-modal">
-      <ModalBody>
-        <ModalHeader tag="h4" className="mb-3 p-0">
-          <span className="material-symbols-outlined me-2 text-primary">psychology</span>
-          GROWI Assistant
-        </ModalHeader>
-
+    <Modal size="lg" isOpen={isOpened} toggle={closeRagSearchModal} className={moduleClass}>
+      <ModalHeader tag="h4" toggle={closeRagSearchModal} className="pe-4">
+        <span className="material-symbols-outlined text-primary">psychology</span>
+        GROWI Assistant
+      </ModalHeader>
+      <ModalBody className="px-lg-5 py-4">
         <div className="vstack gap-4">
-          { messages.map(message => (
+          { messageLogs.map(message => (
             <MessageCard key={message.id} right={message.isUserMessage}>{message.content}</MessageCard>
           )) }
+          { lastMessage != null && (
+            <MessageCard>{lastMessage.content}</MessageCard>
+          )}
         </div>
 
-        <div className="input-group mt-5">
-          <input
-            type="text"
-            className="form-control"
-            placeholder="お手伝いできることはありますか?"
-            aria-label="Recipient's username"
-            aria-describedby="button-addon2"
-            value={input}
-            onChange={e => setInput(e.target.value)}
-          />
-          <button
-            type="button"
-            id="button-addon2"
-            className="btn btn-outline-secondary"
-            onClick={onClickSubmitUserMessageHandler}
-          >
-            <span className="material-symbols-outlined">arrow_upward</span>
-          </button>
+        <div>
+          <form onSubmit={form.handleSubmit(submit)} className="hstack gap-2 align-items-end mt-4">
+            <Controller
+              name="input"
+              control={form.control}
+              render={({ field }) => (
+                <ResizableTextarea
+                  {...field}
+                  required
+                  className="form-control textarea-ask"
+                  style={{ resize: 'none' }}
+                  rows={1}
+                  placeholder="ききたいことを入力してください"
+                  onKeyDown={keyDownHandler}
+                />
+              )}
+            />
+            <button
+              type="submit"
+              className="btn btn-submit no-border"
+              disabled={form.formState.isSubmitting}
+            >
+              <span className="material-symbols-outlined">send</span>
+            </button>
+          </form>
+
+          {form.formState.errors.input != null && (
+            <span className="text-danger small">{form.formState.errors.input?.message}</span>
+          )}
         </div>
       </ModalBody>
     </Modal>

+ 22 - 0
apps/app/src/client/components/RagSearch/ResizableTextArea.tsx

@@ -0,0 +1,22 @@
+import type { ChangeEventHandler, DetailedHTMLProps, TextareaHTMLAttributes } from 'react';
+import { useCallback } from 'react';
+
+type Props = DetailedHTMLProps<TextareaHTMLAttributes<HTMLTextAreaElement>, HTMLTextAreaElement>;
+
+export const ResizableTextarea = (props: Props): JSX.Element => {
+
+  const { onChange: _onChange, ...rest } = props;
+
+  const onChange: ChangeEventHandler<HTMLTextAreaElement> = useCallback((e) => {
+    _onChange?.(e);
+
+    // auto resize
+    // refs: https://zenn.dev/soma3134/articles/1e2fb0eab75b2d
+    e.target.style.height = 'auto';
+    e.target.style.height = `${e.target.scrollHeight + 4}px`;
+  }, [_onChange]);
+
+  return (
+    <textarea onChange={onChange} {...rest} />
+  );
+};

+ 8 - 2
apps/app/src/server/routes/apiv3/openai/index.ts

@@ -1,12 +1,18 @@
 import express from 'express';
 
-import { chatHandlersFactory } from './chat';
+import { postMessageHandlersFactory } from './message';
 import { rebuildVectorStoreHandlersFactory } from './rebuild-vector-store';
+import { createThreadHandlersFactory } from './thread';
 
 const router = express.Router();
 
 module.exports = (crowi) => {
-  router.post('/chat', chatHandlersFactory(crowi));
   router.post('/rebuild-vector-store', rebuildVectorStoreHandlersFactory(crowi));
+
+  // create thread
+  router.post('/thread', createThreadHandlersFactory(crowi));
+  // post message and return streaming with SSE
+  router.post('/message', postMessageHandlersFactory(crowi));
+
   return router;
 };

+ 85 - 0
apps/app/src/server/routes/apiv3/openai/message.ts

@@ -0,0 +1,85 @@
+import assert from 'assert';
+
+import type { Request, RequestHandler, Response } from 'express';
+import type { ValidationChain } from 'express-validator';
+import { body } from 'express-validator';
+import type { AssistantStream } from 'openai/lib/AssistantStream';
+import type { MessageDelta } from 'openai/resources/beta/threads/messages.mjs';
+
+import type Crowi from '~/server/crowi';
+import { openaiClient } from '~/server/service/openai';
+import { getOrCreateChatAssistant } from '~/server/service/openai/assistant';
+import loggerFactory from '~/utils/logger';
+
+import { apiV3FormValidator } from '../../../middlewares/apiv3-form-validator';
+
+
+const logger = loggerFactory('growi:routes:apiv3:openai:message');
+
+
+type ReqBody = {
+  userMessage: string,
+  threadId?: string,
+}
+
+type Req = Request<undefined, Response, ReqBody>
+
+type PostMessageHandlersFactory = (crowi: Crowi) => RequestHandler[];
+
+export const postMessageHandlersFactory: PostMessageHandlersFactory = (crowi) => {
+  const accessTokenParser = require('../../../middlewares/access-token-parser')(crowi);
+  const loginRequiredStrictly = require('../../../middlewares/login-required')(crowi);
+
+  const validator: ValidationChain[] = [
+    body('userMessage').isString().withMessage('userMessage must be string'),
+    body('threadId').isString().withMessage('threadId must be string'),
+  ];
+
+  return [
+    accessTokenParser, loginRequiredStrictly, validator, apiV3FormValidator,
+    async(req: Req, res: Response) => {
+
+      const threadId = req.body.threadId;
+
+      assert(threadId != null);
+
+      let stream: AssistantStream;
+
+      try {
+        const assistant = await getOrCreateChatAssistant();
+
+        const thread = await openaiClient.beta.threads.retrieve(threadId);
+
+        stream = openaiClient.beta.threads.runs.stream(thread.id, {
+          assistant_id: assistant.id,
+          additional_messages: [{ role: 'user', content: req.body.userMessage }],
+        });
+
+      }
+      catch (err) {
+        logger.error(err);
+        return res.status(500).send(err);
+      }
+
+      res.writeHead(200, {
+        'Content-Type': 'text/event-stream;charset=utf-8',
+        'Cache-Control': 'no-cache, no-transform',
+      });
+
+      const messageDeltaHandler = (delta: MessageDelta) => {
+        res.write(`data: ${JSON.stringify(delta)}\n\n`);
+      };
+
+      stream.on('messageDelta', messageDeltaHandler);
+      stream.once('messageDone', () => {
+        stream.off('messageDelta', messageDeltaHandler);
+        res.end();
+      });
+      stream.once('error', (err) => {
+        logger.error(err);
+        stream.off('messageDelta', messageDeltaHandler);
+        res.end();
+      });
+    },
+  ];
+};

+ 10 - 29
apps/app/src/server/routes/apiv3/openai/chat.ts → apps/app/src/server/routes/apiv3/openai/thread.ts

@@ -3,51 +3,42 @@ import type { ValidationChain } from 'express-validator';
 import { body } from 'express-validator';
 
 import type Crowi from '~/server/crowi';
-import { apiV3FormValidator } from '~/server/middlewares/apiv3-form-validator';
-import { certifyAiService } from '~/server/middlewares/certify-ai-service';
-import { configManager } from '~/server/service/config-manager';
 import { openaiClient } from '~/server/service/openai';
-import { getOrCreateChatAssistant } from '~/server/service/openai/assistant';
 import loggerFactory from '~/utils/logger';
 
-
+import { apiV3FormValidator } from '../../../middlewares/apiv3-form-validator';
 import type { ApiV3Response } from '../interfaces/apiv3-response';
 
 const logger = loggerFactory('growi:routes:apiv3:openai:chat');
 
-type ReqBody = {
+type CreateThreadReq = Request<undefined, ApiV3Response, {
   userMessage: string,
   threadId?: string,
-}
-
-type Req = Request<undefined, ApiV3Response, ReqBody>
+}>
 
-type ChatHandlersFactory = (crowi: Crowi) => RequestHandler[];
+type CreateThreadFactory = (crowi: Crowi) => RequestHandler[];
 
-export const chatHandlersFactory: ChatHandlersFactory = (crowi) => {
+export const createThreadHandlersFactory: CreateThreadFactory = (crowi) => {
   const accessTokenParser = require('../../../middlewares/access-token-parser')(crowi);
   const loginRequiredStrictly = require('../../../middlewares/login-required')(crowi);
 
   const validator: ValidationChain[] = [
-    body('userMessage').isString().withMessage('userMessage must be string'),
     body('threadId').optional().isString().withMessage('threadId must be string'),
   ];
 
   return [
-    accessTokenParser, loginRequiredStrictly, certifyAiService, validator, apiV3FormValidator,
-    async(req: Req, res: ApiV3Response) => {
-      const vectorStoreId = configManager.getConfig('crowi', 'app:openaiVectorStoreId');
+    accessTokenParser, loginRequiredStrictly, validator, apiV3FormValidator,
+    async(req: CreateThreadReq, res: ApiV3Response) => {
+
+      const vectorStoreId = process.env.OPENAI_VECTOR_STORE_ID;
       if (vectorStoreId == null) {
         return res.apiv3Err('OPENAI_VECTOR_STORE_ID is not setup', 503);
       }
 
       try {
-        const assistant = await getOrCreateChatAssistant();
-
         const threadId = req.body.threadId;
         const thread = threadId == null
           ? await openaiClient.beta.threads.create({
-            messages: [{ role: 'assistant', content: req.body.userMessage }],
             tool_resources: {
               file_search: {
                 vector_store_ids: [vectorStoreId],
@@ -56,17 +47,7 @@ export const chatHandlersFactory: ChatHandlersFactory = (crowi) => {
           })
           : await openaiClient.beta.threads.retrieve(threadId);
 
-        const run = await openaiClient.beta.threads.runs.createAndPoll(thread.id, { assistant_id: assistant.id });
-
-        if (run.status === 'completed') {
-          const messages = await openaiClient.beta.threads.messages.list(run.thread_id, {
-            limit: 1,
-            order: 'desc',
-          });
-          return res.apiv3({ messages });
-        }
-
-        return res.apiv3({});
+        return res.apiv3({ thread });
       }
       catch (err) {
         logger.error(err);