Просмотр исходного кода

Merge remote-tracking branch 'origin/feat/growi-ai-next' into feat/unified-merge-view

Yuki Takei 1 год назад
Родитель
Сommit
0e7eaec696

+ 17 - 14
apps/app/src/features/openai/client/components/AiAssistant/AiAssistantChatSidebar/AiAssistantChatSidebar.tsx

@@ -11,6 +11,7 @@ import SimpleBar from 'simplebar-react';
 import { apiv3Post } from '~/client/util/apiv3-client';
 import { toastError } from '~/client/util/toastr';
 import { MessageErrorCode, StreamErrorCode } from '~/features/openai/interfaces/message-error';
+import type { IThreadRelationHasId } from '~/features/openai/interfaces/thread-relation';
 import { useGrowiCloudUri } from '~/stores-universal/context';
 import loggerFactory from '~/utils/logger';
 
@@ -43,16 +44,17 @@ type FormData = {
 
 type AiAssistantChatSidebarSubstanceProps = {
   aiAssistantData: AiAssistantHasId;
-  threadId?: string;
+  threadData?: IThreadRelationHasId;
   closeAiAssistantChatSidebar: () => void
 }
 
 const AiAssistantChatSidebarSubstance: React.FC<AiAssistantChatSidebarSubstanceProps> = (props: AiAssistantChatSidebarSubstanceProps) => {
   const {
-    threadId, aiAssistantData, closeAiAssistantChatSidebar,
+    aiAssistantData, threadData, closeAiAssistantChatSidebar,
   } = props;
 
-  const [currentThreadId, setCurrentThreadId] = useState<string | undefined>(threadId);
+  const [currentThreadTitle, setCurrentThreadTitle] = useState<string | undefined>(threadData?.title);
+  const [currentThreadId, setCurrentThreadId] = useState<string | undefined>(threadData?.threadId);
   const [messageLogs, setMessageLogs] = useState<Message[]>([]);
   const [generatingAnswerMessage, setGeneratingAnswerMessage] = useState<Message>();
   const [errorMessage, setErrorMessage] = useState<string | undefined>();
@@ -61,7 +63,7 @@ const AiAssistantChatSidebarSubstance: React.FC<AiAssistantChatSidebarSubstanceP
   const { t } = useTranslation();
   const { data: growiCloudUri } = useGrowiCloudUri();
   const { trigger: mutateThreadData } = useSWRMUTxThreads(aiAssistantData._id);
-  const { trigger: mutateMessageData } = useSWRMUTxMessages(aiAssistantData._id, threadId);
+  const { trigger: mutateMessageData } = useSWRMUTxMessages(aiAssistantData._id, threadData?.threadId);
 
   const form = useForm<FormData>({
     defaultValues: {
@@ -87,10 +89,10 @@ const AiAssistantChatSidebarSubstance: React.FC<AiAssistantChatSidebarSubstanceP
       }
     };
 
-    if (threadId != null) {
+    if (threadData != null) {
       getMessageData();
     }
-  }, [mutateMessageData, threadId]);
+  }, [mutateMessageData, threadData]);
 
   const isGenerating = generatingAnswerMessage != null;
   const submit = useCallback(async(data: FormData) => {
@@ -122,11 +124,12 @@ const AiAssistantChatSidebarSubstance: React.FC<AiAssistantChatSidebarSubstanceP
     let currentThreadId_ = currentThreadId;
     if (currentThreadId_ == null) {
       try {
-        const res = await apiv3Post('/openai/thread', { aiAssistantId: aiAssistantData._id });
-        const thread = res.data.thread;
+        const res = await apiv3Post<IThreadRelationHasId>('/openai/thread', { aiAssistantId: aiAssistantData._id, initialUserMessage: newUserMessage.content });
+        const thread = res.data;
 
-        setCurrentThreadId(thread.id);
-        currentThreadId_ = thread.id;
+        setCurrentThreadId(thread.threadId);
+        setCurrentThreadTitle(thread.title);
+        currentThreadId_ = thread.threadId;
 
         // No need to await because data is not used
         mutateThreadData();
@@ -221,7 +224,7 @@ const AiAssistantChatSidebarSubstance: React.FC<AiAssistantChatSidebarSubstanceP
       form.setError('input', { type: 'manual', message: err.toString() });
     }
 
-  }, [aiAssistantData._id, currentThreadId, form, growiCloudUri, isGenerating, messageLogs, mutateThreadData, t]);
+  }, [isGenerating, messageLogs, form, currentThreadId, aiAssistantData._id, mutateThreadData, t, growiCloudUri]);
 
   const keyDownHandler = (event: KeyboardEvent<HTMLTextAreaElement>) => {
     if (event.key === 'Enter' && (event.ctrlKey || event.metaKey)) {
@@ -234,7 +237,7 @@ const AiAssistantChatSidebarSubstance: React.FC<AiAssistantChatSidebarSubstanceP
       <div className="d-flex flex-column vh-100">
         <div className="d-flex align-items-center p-3 border-bottom">
           <span className="growi-custom-icons growi-ai-chat-icon me-3 fs-4">ai_assistant</span>
-          <h5 className="mb-0 fw-bold flex-grow-1 text-truncate">{aiAssistantData.name}</h5>
+          <h5 className="mb-0 fw-bold flex-grow-1 text-truncate">{currentThreadTitle ?? aiAssistantData.name}</h5>
           <button
             type="button"
             className="btn btn-link p-0 border-0"
@@ -404,7 +407,7 @@ export const AiAssistantChatSidebar: FC = memo((): JSX.Element => {
   const { data: aiAssistantChatSidebarData, close: closeAiAssistantChatSidebar } = useAiAssistantChatSidebar();
 
   const aiAssistantData = aiAssistantChatSidebarData?.aiAssistantData;
-  const threadId = aiAssistantChatSidebarData?.threadId;
+  const threadData = aiAssistantChatSidebarData?.threadData;
   const isOpened = aiAssistantChatSidebarData?.isOpened && aiAssistantData != null;
 
   useEffect(() => {
@@ -437,7 +440,7 @@ export const AiAssistantChatSidebar: FC = memo((): JSX.Element => {
         autoHide
       >
         <AiAssistantChatSidebarSubstance
-          threadId={threadId}
+          threadData={threadData}
           aiAssistantData={aiAssistantData}
           closeAiAssistantChatSidebar={closeAiAssistantChatSidebar}
         />

+ 6 - 6
apps/app/src/features/openai/client/components/AiAssistant/Sidebar/AiAssistantTree.tsx

@@ -23,7 +23,7 @@ const moduleClass = styles['ai-assistant-tree-item'] ?? '';
 type ThreadItemProps = {
   thread: IThreadRelationHasId
   aiAssistantData: AiAssistantHasId;
-  onThreadClick: (aiAssistantData: AiAssistantHasId, threadId?: string) => void;
+  onThreadClick: (aiAssistantData: AiAssistantHasId, threadData?: IThreadRelationHasId) => void;
 };
 
 const ThreadItem: React.FC<ThreadItemProps> = ({ thread, aiAssistantData, onThreadClick }) => {
@@ -33,8 +33,8 @@ const ThreadItem: React.FC<ThreadItemProps> = ({ thread, aiAssistantData, onThre
   }, []);
 
   const openChatHandler = useCallback(() => {
-    onThreadClick(aiAssistantData, thread.threadId);
-  }, [aiAssistantData, onThreadClick, thread.threadId]);
+    onThreadClick(aiAssistantData, thread);
+  }, [aiAssistantData, onThreadClick, thread]);
 
   return (
     <li
@@ -47,7 +47,7 @@ const ThreadItem: React.FC<ThreadItemProps> = ({ thread, aiAssistantData, onThre
       </div>
 
       <div className="grw-ai-assistant-title-anchor ps-1">
-        <p className="text-truncate m-auto">{thread.threadId}</p>
+        <p className="text-truncate m-auto">{thread?.title ?? 'Untitled thread'}</p>
       </div>
 
       <div className="grw-ai-assistant-actions opacity-0 d-flex justify-content-center ">
@@ -69,7 +69,7 @@ const ThreadItem: React.FC<ThreadItemProps> = ({ thread, aiAssistantData, onThre
 */
 type ThreadItemsProps = {
   aiAssistantData: AiAssistantHasId;
-  onThreadClick: (aiAssistantData: AiAssistantHasId, threadId?: string) => void;
+  onThreadClick: (aiAssistantData: AiAssistantHasId, threadData?: IThreadRelationHasId) => void;
 };
 
 const ThreadItems: React.FC<ThreadItemsProps> = ({ aiAssistantData, onThreadClick }) => {
@@ -113,7 +113,7 @@ type AiAssistantItemProps = {
   currentUserId?: string;
   aiAssistant: AiAssistantHasId;
   onEditClick: (aiAssistantData: AiAssistantHasId) => void;
-  onItemClick: (aiAssistantData: AiAssistantHasId, threadId?: string) => void;
+  onItemClick: (aiAssistantData: AiAssistantHasId, threadData?: IThreadRelationHasId) => void;
   onDeleted?: () => void;
 };
 

+ 6 - 3
apps/app/src/features/openai/client/stores/ai-assistant.tsx

@@ -7,6 +7,7 @@ import useSWRImmutable from 'swr/immutable';
 import { apiv3Get } from '~/client/util/apiv3-client';
 
 import { type AccessibleAiAssistantsHasId, type AiAssistantHasId } from '../../interfaces/ai-assistant';
+import type { IThreadRelationHasId } from '../../interfaces/thread-relation';
 
 export const AiAssistantManagementModalPageMode = {
   HOME: 'home',
@@ -57,13 +58,13 @@ export const useSWRxAiAssistants = (): SWRResponse<AccessibleAiAssistantsHasId,
 type AiAssistantChatSidebarStatus = {
   isOpened: boolean,
   aiAssistantData?: AiAssistantHasId,
-  threadId?: string,
+  threadData?: IThreadRelationHasId,
 }
 
 type AiAssistantChatSidebarUtils = {
   open(
     aiAssistantData: AiAssistantHasId,
-    threadId?: string,
+    threadData?: IThreadRelationHasId,
   ): void
   close(): void
 }
@@ -77,7 +78,9 @@ export const useAiAssistantChatSidebar = (
   return {
     ...swrResponse,
     open: useCallback(
-      (aiAssistantData: AiAssistantHasId, threadId?: string) => { swrResponse.mutate({ isOpened: true, aiAssistantData, threadId }) }, [swrResponse],
+      (aiAssistantData: AiAssistantHasId, threadData: IThreadRelationHasId) => {
+        swrResponse.mutate({ isOpened: true, aiAssistantData, threadData });
+      }, [swrResponse],
     ),
     close: useCallback(() => swrResponse.mutate({ isOpened: false }), [swrResponse]),
   };

+ 1 - 0
apps/app/src/features/openai/interfaces/thread-relation.ts

@@ -6,6 +6,7 @@ export interface IThreadRelation {
   userId: Ref<IUser>
   vectorStore: Ref<IVectorStore>
   threadId: string;
+  title?: string;
   expiredAt: Date;
 }
 

+ 3 - 0
apps/app/src/features/openai/server/models/thread-relation.ts

@@ -35,6 +35,9 @@ const schema = new Schema<ThreadRelationDocument, ThreadRelationModel>({
     required: true,
     unique: true,
   },
+  title: {
+    type: String,
+  },
   expiredAt: {
     type: Date,
     default: generateExpirationDate,

+ 5 - 3
apps/app/src/features/openai/server/routes/thread.ts

@@ -20,6 +20,7 @@ const logger = loggerFactory('growi:routes:apiv3:openai:thread');
 type ReqBody = {
   aiAssistantId: string,
   threadId?: string,
+  initialUserMessage?: string,
 }
 
 type CreateThreadReq = Request<undefined, ApiV3Response, ReqBody> & { user: IUserHasId };
@@ -32,6 +33,7 @@ export const createThreadHandlersFactory: CreateThreadFactory = (crowi) => {
   const validator: ValidationChain[] = [
     body('aiAssistantId').isMongoId().withMessage('aiAssistantId must be string'),
     body('threadId').optional().isString().withMessage('threadId must be string'),
+    body('initialUserMessage').optional().isString().withMessage('initialUserMessage must be string'),
   ];
 
   return [
@@ -44,7 +46,7 @@ export const createThreadHandlersFactory: CreateThreadFactory = (crowi) => {
       }
 
       try {
-        const { aiAssistantId, threadId } = req.body;
+        const { aiAssistantId, threadId, initialUserMessage } = req.body;
 
         const isAiAssistantUsable = await openaiService.isAiAssistantUsable(aiAssistantId, req.user);
         if (!isAiAssistantUsable) {
@@ -54,8 +56,8 @@ export const createThreadHandlersFactory: CreateThreadFactory = (crowi) => {
         const filteredThreadId = threadId != null ? filterXSS(threadId) : undefined;
         const vectorStoreRelation = await openaiService.getVectorStoreRelation(aiAssistantId);
 
-        const thread = await openaiService.getOrCreateThread(req.user._id, vectorStoreRelation, filteredThreadId);
-        return res.apiv3({ thread });
+        const thread = await openaiService.getOrCreateThread(req.user._id, vectorStoreRelation, filteredThreadId, initialUserMessage);
+        return res.apiv3(thread);
       }
       catch (err) {
         logger.error(err);

+ 4 - 0
apps/app/src/features/openai/server/services/client-delegator/azure-openai-client-delegator.ts

@@ -74,4 +74,8 @@ export class AzureOpenaiClientDelegator implements IOpenaiClientDelegator {
     return this.client.beta.vectorStores.fileBatches.uploadAndPoll(vectorStoreId, { files });
   }
 
+  async chatCompletion(body: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming): Promise<OpenAI.Chat.Completions.ChatCompletion> {
+    return this.client.chat.completions.create(body);
+  }
+
 }

+ 1 - 0
apps/app/src/features/openai/server/services/client-delegator/interfaces.ts

@@ -12,4 +12,5 @@ export interface IOpenaiClientDelegator {
   uploadFile(file: Uploadable): Promise<OpenAI.Files.FileObject>
   createVectorStoreFileBatch(vectorStoreId: string, fileIds: string[]): Promise<OpenAI.Beta.VectorStores.FileBatches.VectorStoreFileBatch>
   deleteFile(fileId: string): Promise<OpenAI.Files.FileDeleted>;
+  chatCompletion(body: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming): Promise<OpenAI.Chat.Completions.ChatCompletion>
 }

+ 4 - 0
apps/app/src/features/openai/server/services/client-delegator/openai-client-delegator.ts

@@ -77,4 +77,8 @@ export class OpenaiClientDelegator implements IOpenaiClientDelegator {
     return this.client.beta.vectorStores.fileBatches.uploadAndPoll(vectorStoreId, { files });
   }
 
+  async chatCompletion(body: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming): Promise<OpenAI.Chat.Completions.ChatCompletion> {
+    return this.client.chat.completions.create(body);
+  }
+
 }

+ 50 - 5
apps/app/src/features/openai/server/services/openai.ts

@@ -62,7 +62,9 @@ const convertPathPatternsToRegExp = (pagePathPatterns: string[]): Array<string |
 };
 
 export interface IOpenaiService {
-  getOrCreateThread(userId: string, vectorStoreRelation: VectorStoreDocument, threadId?: string): Promise<OpenAI.Beta.Threads.Thread | undefined>;
+  getOrCreateThread(
+    userId: string, vectorStoreRelation: VectorStoreDocument, threadId?: string, initialUserMessage?: string
+  ): Promise<ThreadRelationDocument>;
   getThreads(vectorStoreRelationId: string): Promise<ThreadRelationDocument[]>
   // getOrCreateVectorStoreForPublicScope(): Promise<VectorStoreDocument>;
   deleteExpiredThreads(limit: number, apiCallInterval: number): Promise<void>; // for CronJob
@@ -93,12 +95,55 @@ class OpenaiService implements IOpenaiService {
     return getClient({ openaiServiceType });
   }
 
-  public async getOrCreateThread(userId: string, vectorStoreRelation: VectorStoreDocument, threadId?: string): Promise<OpenAI.Beta.Threads.Thread> {
+  async generateThreadTitle(message: string): Promise<string | null> {
+    const model = configManager.getConfig('openai:assistantModel:chat');
+    const systemMessage = [
+      'Create a brief title (max 5 words) from your message.',
+      'Respond in the same language the user uses in their input.',
+      'Response should only contain the title.',
+    ].join('');
+
+    const threadTitleCompletion = await this.client.chatCompletion({
+      model,
+      messages: [
+        {
+          role: 'system',
+          content: systemMessage,
+        },
+        {
+          role: 'user',
+          content: message,
+        },
+      ],
+    });
+
+    const threadTitle = threadTitleCompletion.choices[0].message.content;
+    return threadTitle;
+  }
+
+  async getOrCreateThread(
+      userId: string, vectorStoreRelation: VectorStoreDocument, threadId?: string, initialUserMessage?: string,
+  ): Promise<ThreadRelationDocument> {
     if (threadId == null) {
+      let threadTitle: string | null = null;
+      if (initialUserMessage != null) {
+        try {
+          threadTitle = await this.generateThreadTitle(initialUserMessage);
+        }
+        catch (err) {
+          logger.error(err);
+        }
+      }
+
       try {
         const thread = await this.client.createThread(vectorStoreRelation.vectorStoreId);
-        await ThreadRelationModel.create({ userId, threadId: thread.id, vectorStore: vectorStoreRelation._id });
-        return thread;
+        const threadRelation = await ThreadRelationModel.create({
+          userId,
+          threadId: thread.id,
+          vectorStore: vectorStoreRelation._id,
+          title: threadTitle,
+        });
+        return threadRelation;
       }
       catch (err) {
         throw new Error(err);
@@ -118,7 +163,7 @@ class OpenaiService implements IOpenaiService {
       // Update expiration date if thread entity exists
       await threadRelation.updateThreadExpiration();
 
-      return thread;
+      return threadRelation;
     }
     catch (err) {
       await openaiApiErrorHandler(err, { notFoundError: async() => { await threadRelation.remove() } });