Просмотр исходного кода

Merge pull request #9659 from weseek/feat/162019-api-to-return-thread-list

feat: API to return thread list
Shun Miyazawa 1 год назад
Родитель
Сommit
e0549cc4ee

+ 3 - 2
apps/app/src/features/openai/client/components/AiAssistant/AiAssistantChatSidebar/AiAssistantChatSidebar.tsx

@@ -94,9 +94,10 @@ const AiAssistantChatSidebarSubstance: React.FC<AiAssistantChatSidebarSubstanceP
 
     // create thread
     let currentThreadId = threadId;
-    if (threadId == null) {
+    const aiAssistantId = aiAssistantData?._id;
+    if (threadId == null && aiAssistantId) {
       try {
-        const res = await apiv3Post('/openai/thread');
+        const res = await apiv3Post('/openai/thread', { aiAssistantId });
         const thread = res.data.thread;
 
         setThreadId(thread.id);

+ 45 - 28
apps/app/src/features/openai/client/components/AiAssistant/Sidebar/AiAssistantTree.tsx

@@ -3,35 +3,28 @@ import React, { useCallback, useState } from 'react';
 import { getIdStringForRef } from '@growi/core';
 
 import { toastError, toastSuccess } from '~/client/util/toastr';
+import type { IThreadRelationHasId } from '~/features/openai/interfaces/thread-relation';
 import { useCurrentUser } from '~/stores-universal/context';
 
 import type { AiAssistantAccessScope } from '../../../../interfaces/ai-assistant';
 import { AiAssistantShareScope, type AiAssistantHasId } from '../../../../interfaces/ai-assistant';
 import { deleteAiAssistant } from '../../../services/ai-assistant';
 import { useAiAssistantChatSidebar, useAiAssistantManagementModal } from '../../../stores/ai-assistant';
+import { useSWRxThreads } from '../../../stores/thread';
 
 import styles from './AiAssistantTree.module.scss';
 
 const moduleClass = styles['ai-assistant-tree-item'] ?? '';
 
-type Thread = {
-  _id: string;
-  name: string;
-}
-
-const dummyThreads: Thread[] = [
-  { _id: '1', name: 'thread1' },
-  { _id: '2', name: 'thread2' },
-  { _id: '3', name: 'thread3' },
-];
 
+/*
+*  ThreadItem
+*/
 type ThreadItemProps = {
-  thread: Thread;
+  thread: IThreadRelationHasId
 };
 
-const ThreadItem: React.FC<ThreadItemProps> = ({
-  thread,
-}) => {
+const ThreadItem: React.FC<ThreadItemProps> = ({ thread }) => {
 
   const deleteThreadHandler = useCallback(() => {
     // TODO: https://redmine.weseek.co.jp/issues/161490
@@ -52,7 +45,7 @@ const ThreadItem: React.FC<ThreadItemProps> = ({
       </div>
 
       <div className="grw-ai-assistant-title-anchor ps-1">
-        <p className="text-truncate m-auto">{thread.name}</p>
+        <p className="text-truncate m-auto">{thread.threadId}</p>
       </div>
 
       <div className="grw-ai-assistant-actions opacity-0 d-flex justify-content-center ">
@@ -69,6 +62,36 @@ const ThreadItem: React.FC<ThreadItemProps> = ({
 };
 
 
+/*
+*  ThreadItems
+*/
+type ThreadItemsProps = {
+  aiAssistantId: string;
+};
+
+const ThreadItems: React.FC<ThreadItemsProps> = ({ aiAssistantId }) => {
+  const { data: threads } = useSWRxThreads(aiAssistantId);
+
+  if (threads == null || threads.length === 0) {
+    return <p className="text-secondary ms-5">スレッドが存在しません</p>;
+  }
+
+  return (
+    <div className="grw-ai-assistant-item-children">
+      {threads.map(thread => (
+        <ThreadItem
+          key={thread._id}
+          thread={thread}
+        />
+      ))}
+    </div>
+  );
+};
+
+
+/*
+*  AiAssistantItem
+*/
 const getShareScopeIcon = (shareScope: AiAssistantShareScope, accessScope: AiAssistantAccessScope): string => {
   const determinedSharedScope = shareScope === AiAssistantShareScope.SAME_AS_ACCESS_SCOPE ? accessScope : shareScope;
   switch (determinedSharedScope) {
@@ -84,7 +107,6 @@ const getShareScopeIcon = (shareScope: AiAssistantShareScope, accessScope: AiAss
 type AiAssistantItemProps = {
   currentUserId?: string;
   aiAssistant: AiAssistantHasId;
-  threads: Thread[];
   onEditClicked?: (aiAssistantData: AiAssistantHasId) => void;
   onItemClicked?: (aiAssistantData: AiAssistantHasId) => void;
   onDeleted?: () => void;
@@ -93,7 +115,6 @@ type AiAssistantItemProps = {
 const AiAssistantItem: React.FC<AiAssistantItemProps> = ({
   currentUserId,
   aiAssistant,
-  threads,
   onEditClicked,
   onItemClicked,
   onDeleted,
@@ -185,20 +206,17 @@ const AiAssistantItem: React.FC<AiAssistantItemProps> = ({
         )}
       </li>
 
-      {isThreadsOpened && threads.length > 0 && (
-        <div className="grw-ai-assistant-item-children">
-          {threads.map(thread => (
-            <ThreadItem
-              key={thread._id}
-              thread={thread}
-            />
-          ))}
-        </div>
-      )}
+      { isThreadsOpened && (
+        <ThreadItems aiAssistantId={aiAssistant._id} />
+      ) }
     </>
   );
 };
 
+
+/*
+*  AiAssistantTree
+*/
 type AiAssistantTreeProps = {
   aiAssistants: AiAssistantHasId[];
   onDeleted?: () => void;
@@ -216,7 +234,6 @@ export const AiAssistantTree: React.FC<AiAssistantTreeProps> = ({ aiAssistants,
           key={assistant._id}
           currentUserId={currentUser?._id}
           aiAssistant={assistant}
-          threads={dummyThreads}
           onEditClicked={openAiAssistantManagementModal}
           onItemClicked={openAiAssistantChatSidebar}
           onDeleted={onDeleted}

+ 15 - 0
apps/app/src/features/openai/client/stores/thread.tsx

@@ -0,0 +1,15 @@
+import { type SWRResponse } from 'swr';
+import useSWRImmutable from 'swr/immutable';
+
+import { apiv3Get } from '~/client/util/apiv3-client';
+
+import type { IThreadRelationHasId } from '../../interfaces/thread-relation';
+
+export const useSWRxThreads = (aiAssistantId: string): SWRResponse<IThreadRelationHasId[], Error> => {
+  const key = [`openai/threads/${aiAssistantId}`];
+
+  return useSWRImmutable<IThreadRelationHasId[]>(
+    key,
+    ([endpoint]) => apiv3Get(endpoint).then(response => response.data.threads),
+  );
+};

+ 2 - 2
apps/app/src/features/openai/interfaces/ai-assistant.ts

@@ -2,7 +2,7 @@ import type {
   IGrantedGroup, IUser, Ref, HasObjectId,
 } from '@growi/core';
 
-import type { VectorStore } from '../server/models/vector-store';
+import type { IVectorStore } from './vector-store';
 
 /*
 *  Objects
@@ -31,7 +31,7 @@ export interface AiAssistant {
   description: string
   additionalInstruction: string
   pagePathPatterns: string[],
-  vectorStore: Ref<VectorStore>
+  vectorStore: Ref<IVectorStore>
   owner: Ref<IUser>
   grantedGroupsForShareScope?: IGrantedGroup[]
   grantedGroupsForAccessScope?: IGrantedGroup[]

+ 12 - 0
apps/app/src/features/openai/interfaces/thread-relation.ts

@@ -0,0 +1,12 @@
+import type { IUser, Ref, HasObjectId } from '@growi/core';
+
+import type { IVectorStore } from './vector-store';
+
+export interface IThreadRelation {
+  userId: Ref<IUser>
+  vectorStore: Ref<IVectorStore>
+  threadId: string;
+  expiredAt: Date;
+}
+
+export type IThreadRelationHasId = IThreadRelation & HasObjectId;

+ 4 - 0
apps/app/src/features/openai/interfaces/vector-store.ts

@@ -0,0 +1,4 @@
+export interface IVectorStore {
+  vectorStoreId: string
+  isDeleted: boolean
+}

+ 8 - 8
apps/app/src/features/openai/server/models/thread-relation.ts

@@ -1,22 +1,17 @@
 import { addDays } from 'date-fns';
-import type mongoose from 'mongoose';
 import { type Model, type Document, Schema } from 'mongoose';
 
 import { getOrCreateModel } from '~/server/util/mongoose-utils';
 
+import type { IThreadRelation } from '../../interfaces/thread-relation';
+
 const DAYS_UNTIL_EXPIRATION = 3;
 
 const generateExpirationDate = (): Date => {
   return addDays(new Date(), DAYS_UNTIL_EXPIRATION);
 };
 
-interface ThreadRelation {
-  userId: mongoose.Types.ObjectId;
-  threadId: string;
-  expiredAt: Date;
-}
-
-interface ThreadRelationDocument extends ThreadRelation, Document {
+export interface ThreadRelationDocument extends IThreadRelation, Document {
   updateThreadExpiration(): Promise<void>;
 }
 
@@ -30,6 +25,11 @@ const schema = new Schema<ThreadRelationDocument, ThreadRelationModel>({
     ref: 'User',
     required: true,
   },
+  vectorStore: {
+    type: Schema.Types.ObjectId,
+    ref: 'VectorStore',
+    required: true,
+  },
   threadId: {
     type: String,
     required: true,

+ 3 - 6
apps/app/src/features/openai/server/models/vector-store.ts

@@ -2,16 +2,13 @@ import { type Model, type Document, Schema } from 'mongoose';
 
 import { getOrCreateModel } from '~/server/util/mongoose-utils';
 
-export interface VectorStore {
-  vectorStoreId: string
-  isDeleted: boolean
-}
+import type { IVectorStore } from '../../interfaces/vector-store';
 
-export interface VectorStoreDocument extends VectorStore, Document {
+export interface VectorStoreDocument extends IVectorStore, Document {
   markAsDeleted(): Promise<void>
 }
 
-type VectorStoreModel = Model<VectorStore>
+type VectorStoreModel = Model<VectorStoreDocument>;
 
 const schema = new Schema<VectorStoreDocument, VectorStoreModel>({
   vectorStoreId: {

+ 62 - 0
apps/app/src/features/openai/server/routes/get-threads.ts

@@ -0,0 +1,62 @@
+import { type IUserHasId } from '@growi/core';
+import { ErrorV3 } from '@growi/core/dist/models';
+import type { Request, RequestHandler } from 'express';
+import { type ValidationChain, param } from 'express-validator';
+
+import type Crowi from '~/server/crowi';
+import { accessTokenParser } from '~/server/middlewares/access-token-parser';
+import { apiV3FormValidator } from '~/server/middlewares/apiv3-form-validator';
+import type { ApiV3Response } from '~/server/routes/apiv3/interfaces/apiv3-response';
+import loggerFactory from '~/utils/logger';
+
+import { getOpenaiService } from '../services/openai';
+
+import { certifyAiService } from './middlewares/certify-ai-service';
+
+const logger = loggerFactory('growi:routes:apiv3:openai:get-threads');
+
+type GetThreadsFactory = (crowi: Crowi) => RequestHandler[];
+
+type ReqParams = {
+  aiAssistantId: string,
+}
+
+type Req = Request<ReqParams, Response, undefined> & {
+  user: IUserHasId,
+}
+
+export const getThreadsFactory: GetThreadsFactory = (crowi) => {
+  const loginRequiredStrictly = require('~/server/middlewares/login-required')(crowi);
+
+  const validator: ValidationChain[] = [
+    param('aiAssistantId').isMongoId().withMessage('aiAssistantId must be string'),
+  ];
+
+  return [
+    accessTokenParser, loginRequiredStrictly, certifyAiService, validator, apiV3FormValidator,
+    async(req: Req, res: ApiV3Response) => {
+      const openaiService = getOpenaiService();
+      if (openaiService == null) {
+        return res.apiv3Err(new ErrorV3('GROWI AI is not enabled'), 501);
+      }
+
+      try {
+        const { aiAssistantId } = req.params;
+
+        const isAiAssistantUsable = openaiService.isAiAssistantUsable(aiAssistantId, req.user);
+        if (!isAiAssistantUsable) {
+          return res.apiv3Err(new ErrorV3('The specified AI assistant is not usable'), 400);
+        }
+
+        const vectorStoreRelation = await openaiService.getVectorStoreRelation(aiAssistantId);
+        const threads = await openaiService.getThreads(vectorStoreRelation._id);
+
+        return res.apiv3({ threads });
+      }
+      catch (err) {
+        logger.error(err);
+        return res.apiv3Err(new ErrorV3('Failed to get threads'));
+      }
+    },
+  ];
+};

+ 4 - 0
apps/app/src/features/openai/server/routes/index.ts

@@ -27,6 +27,10 @@ export const factory = (crowi: Crowi): express.Router => {
       router.post('/thread', createThreadHandlersFactory(crowi));
     });
 
+    import('./get-threads').then(({ getThreadsFactory }) => {
+      router.get('/threads/:aiAssistantId', getThreadsFactory(crowi));
+    });
+
     import('./message').then(({ postMessageHandlersFactory }) => {
       router.post('/message', postMessageHandlersFactory(crowi));
     });

+ 19 - 5
apps/app/src/features/openai/server/routes/thread.ts

@@ -17,7 +17,12 @@ import { certifyAiService } from './middlewares/certify-ai-service';
 
 const logger = loggerFactory('growi:routes:apiv3:openai:thread');
 
-type CreateThreadReq = Request<undefined, ApiV3Response, { threadId?: string }> & { user: IUserHasId };
+type ReqBody = {
+  aiAssistantId: string,
+  threadId?: string,
+}
+
+type CreateThreadReq = Request<undefined, ApiV3Response, ReqBody> & { user: IUserHasId };
 
 type CreateThreadFactory = (crowi: Crowi) => RequestHandler[];
 
@@ -25,6 +30,7 @@ export const createThreadHandlersFactory: CreateThreadFactory = (crowi) => {
   const loginRequiredStrictly = require('~/server/middlewares/login-required')(crowi);
 
   const validator: ValidationChain[] = [
+    body('aiAssistantId').isMongoId().withMessage('aiAssistantId must be string'),
     body('threadId').optional().isString().withMessage('threadId must be string'),
   ];
 
@@ -38,10 +44,18 @@ export const createThreadHandlersFactory: CreateThreadFactory = (crowi) => {
       }
 
       try {
-        const filterdThreadId = req.body.threadId != null ? filterXSS(req.body.threadId) : undefined;
-        // const vectorStore = await openaiService?.getOrCreateVectorStoreForPublicScope();
-        // const thread = await openaiService?.getOrCreateThread(req.user._id, vectorStore?.vectorStoreId, filterdThreadId);
-        return res.apiv3({ });
+        const { aiAssistantId, threadId } = req.body;
+
+        const isAiAssistantUsable = await openaiService.isAiAssistantUsable(aiAssistantId, req.user);
+        if (!isAiAssistantUsable) {
+          return res.apiv3Err(new ErrorV3('The specified AI assistant is not usable'), 400);
+        }
+
+        const filteredThreadId = threadId != null ? filterXSS(threadId) : undefined;
+        const vectorStoreRelation = await openaiService.getVectorStoreRelation(aiAssistantId);
+
+        const thread = await openaiService.getOrCreateThread(req.user._id, vectorStoreRelation, filteredThreadId);
+        return res.apiv3({ thread });
       }
       catch (err) {
         logger.error(err);

+ 3 - 0
apps/app/src/features/openai/server/services/normalize-data/normalize-thread-relation-expired-at/normalize-thread-relation-expired-at.integ.ts

@@ -15,6 +15,7 @@ describe('normalizeExpiredAtForThreadRelations', () => {
     const threadRelation = new ThreadRelation({
       userId: new Types.ObjectId(),
       threadId: 'test-thread',
+      vectorStore: new Types.ObjectId(),
       expiredAt: expiredDate,
     });
     await threadRelation.save();
@@ -36,6 +37,7 @@ describe('normalizeExpiredAtForThreadRelations', () => {
     const threadRelation = new ThreadRelation({
       userId: new Types.ObjectId(),
       threadId: 'test-thread-2',
+      vectorStore: new Types.ObjectId(),
       expiredAt: nonExpiredDate,
     });
     await threadRelation.save();
@@ -55,6 +57,7 @@ describe('normalizeExpiredAtForThreadRelations', () => {
     const threadRelation = new ThreadRelation({
       userId: new Types.ObjectId(),
       threadId: 'test-thread-3',
+      vectorStore: new Types.ObjectId(),
       expiredAt: nonExpiredDate,
     });
     await threadRelation.save();

+ 60 - 6
apps/app/src/features/openai/server/services/openai.ts

@@ -13,7 +13,7 @@ import mongoose, { type HydratedDocument, type Types } from 'mongoose';
 import { type OpenAI, toFile } from 'openai';
 
 import ExternalUserGroupRelation from '~/features/external-user-group/server/models/external-user-group-relation';
-import ThreadRelationModel from '~/features/openai/server/models/thread-relation';
+import ThreadRelationModel, { type ThreadRelationDocument } from '~/features/openai/server/models/thread-relation';
 import VectorStoreModel, { type VectorStoreDocument } from '~/features/openai/server/models/vector-store';
 import VectorStoreFileRelationModel, {
   type VectorStoreFileRelation,
@@ -61,16 +61,20 @@ const convertPathPatternsToRegExp = (pagePathPatterns: string[]): Array<string |
 
 
 export interface IOpenaiService {
-  getOrCreateThread(userId: string, vectorStoreId?: string, threadId?: string): Promise<OpenAI.Beta.Threads.Thread | undefined>;
+  getOrCreateThread(userId: string, vectorStoreRelation: VectorStoreDocument, threadId?: string): Promise<OpenAI.Beta.Threads.Thread | undefined>;
+  getThreads(vectorStoreRelationId: string): Promise<ThreadRelationDocument[]>
   // getOrCreateVectorStoreForPublicScope(): Promise<VectorStoreDocument>;
   deleteExpiredThreads(limit: number, apiCallInterval: number): Promise<void>; // for CronJob
   deleteObsolatedVectorStoreRelations(): Promise<void> // for CronJob
+  getVectorStoreRelation(aiAssistantId: string): Promise<VectorStoreDocument>
   getVectorStoreRelationsByPageIds(pageId: Types.ObjectId[]): Promise<VectorStoreDocument[]>;
   createVectorStoreFile(vectorStoreRelation: VectorStoreDocument, pages: PageDocument[]): Promise<void>;
   deleteVectorStoreFile(vectorStoreRelationId: Types.ObjectId, pageId: Types.ObjectId): Promise<void>;
   deleteVectorStoreFilesByPageIds(pageIds: Types.ObjectId[]): Promise<void>;
   deleteObsoleteVectorStoreFile(limit: number, apiCallInterval: number): Promise<void>; // for CronJob
   // rebuildVectorStoreAll(): Promise<void>;
+  // rebuildVectorStore(page: HydratedDocument<PageDocument>): Promise<void>;
+  isAiAssistantUsable(aiAssistantId: string, user: IUserHasId): Promise<boolean>;
   updateVectorStore(page: HydratedDocument<PageDocument>): Promise<void>;
   createAiAssistant(data: Omit<AiAssistant, 'vectorStore'>): Promise<AiAssistantDocument>;
   updateAiAssistant(aiAssistantId: string, data: Omit<AiAssistant, 'vectorStore'>): Promise<AiAssistantDocument>;
@@ -84,11 +88,11 @@ class OpenaiService implements IOpenaiService {
     return getClient({ openaiServiceType });
   }
 
-  public async getOrCreateThread(userId: string, vectorStoreId?: string, threadId?: string): Promise<OpenAI.Beta.Threads.Thread> {
-    if (vectorStoreId != null && threadId == null) {
+  public async getOrCreateThread(userId: string, vectorStoreRelation: VectorStoreDocument, threadId?: string): Promise<OpenAI.Beta.Threads.Thread> {
+    if (threadId == null) {
       try {
-        const thread = await this.client.createThread(vectorStoreId);
-        await ThreadRelationModel.create({ userId, threadId: thread.id });
+        const thread = await this.client.createThread(vectorStoreRelation.vectorStoreId);
+        await ThreadRelationModel.create({ userId, threadId: thread.id, vectorStore: vectorStoreRelation._id });
         return thread;
       }
       catch (err) {
@@ -117,6 +121,11 @@ class OpenaiService implements IOpenaiService {
     }
   }
 
+  async getThreads(vectorStoreRelationId: string): Promise<ThreadRelationDocument[]> {
+    const threadRelations = await ThreadRelationModel.find({ vectorStore: vectorStoreRelationId });
+    return threadRelations;
+  }
+
   public async deleteExpiredThreads(limit: number, apiCallInterval: number): Promise<void> {
     const expiredThreadRelations = await ThreadRelationModel.getExpiredThreadRelations(limit);
     if (expiredThreadRelations == null) {
@@ -174,6 +183,15 @@ class OpenaiService implements IOpenaiService {
   //   return newVectorStoreDocument;
   // }
 
+  async getVectorStoreRelation(aiAssistantId: string): Promise<VectorStoreDocument> {
+    const aiAssistant = await AiAssistantModel.findById({ _id: aiAssistantId }).populate('vectorStore');
+    if (aiAssistant == null) {
+      throw createError(404, 'AiAssistant document does not exist');
+    }
+
+    return aiAssistant.vectorStore as VectorStoreDocument;
+  }
+
   async getVectorStoreRelationsByPageIds(pageIds: Types.ObjectId[]): Promise<VectorStoreDocument[]> {
     const pipeline = [
       // Stage 1: Match documents with the given pageId
@@ -615,6 +633,42 @@ class OpenaiService implements IOpenaiService {
     }
   }
 
+  async isAiAssistantUsable(aiAssistantId: string, user: IUserHasId): Promise<boolean> {
+    const aiAssistant = await AiAssistantModel.findById(aiAssistantId);
+
+    if (aiAssistant == null) {
+      throw createError(404, 'AiAssistant document does not exist');
+    }
+
+    const isOwner = getIdStringForRef(aiAssistant.owner) === getIdStringForRef(user._id);
+
+    if (aiAssistant.shareScope === AiAssistantShareScope.PUBLIC_ONLY) {
+      return true;
+    }
+
+    if ((aiAssistant.shareScope === AiAssistantShareScope.OWNER) && isOwner) {
+      return true;
+    }
+
+    if ((aiAssistant.shareScope === AiAssistantShareScope.SAME_AS_ACCESS_SCOPE) && (aiAssistant.accessScope === AiAssistantAccessScope.OWNER) && isOwner) {
+      return true;
+    }
+
+    if ((aiAssistant.shareScope === AiAssistantShareScope.GROUPS)
+      || ((aiAssistant.shareScope === AiAssistantShareScope.SAME_AS_ACCESS_SCOPE) && (aiAssistant.accessScope === AiAssistantAccessScope.GROUPS))) {
+      const userGroupIds = [
+        ...(await UserGroupRelation.findAllUserGroupIdsRelatedToUser(user)),
+        ...(await ExternalUserGroupRelation.findAllUserGroupIdsRelatedToUser(user)),
+      ].map(group => group.toString());
+
+      const grantedGroupIdsForShareScope = aiAssistant.grantedGroupsForShareScope?.map(group => getIdStringForRef(group.item)) ?? [];
+      const isShared = userGroupIds.some(userGroupId => grantedGroupIdsForShareScope.includes(userGroupId));
+      return isShared;
+    }
+
+    return false;
+  }
+
   async createAiAssistant(data: Omit<AiAssistant, 'vectorStore'>): Promise<AiAssistantDocument> {
     await this.validateGrantedUserGroupsForAiAssistant(
       data.owner,