Просмотр исходного кода

Implement an API to get a list of threads

Shun Miyazawa 1 год назад
Родитель
Сommit
ad362dc78c

+ 1 - 1
apps/app/src/features/openai/server/models/thread-relation.ts

@@ -17,7 +17,7 @@ interface ThreadRelation {
   expiredAt: Date;
 }
 
-interface ThreadRelationDocument extends ThreadRelation, Document {
+export interface ThreadRelationDocument extends ThreadRelation, Document {
   updateThreadExpiration(): Promise<void>;
 }
 

+ 62 - 0
apps/app/src/features/openai/server/routes/get-threads.ts

@@ -0,0 +1,62 @@
+import { type IUserHasId } from '@growi/core';
+import { ErrorV3 } from '@growi/core/dist/models';
+import type { Request, RequestHandler } from 'express';
+import { type ValidationChain, param } from 'express-validator';
+
+import type Crowi from '~/server/crowi';
+import { accessTokenParser } from '~/server/middlewares/access-token-parser';
+import { apiV3FormValidator } from '~/server/middlewares/apiv3-form-validator';
+import type { ApiV3Response } from '~/server/routes/apiv3/interfaces/apiv3-response';
+import loggerFactory from '~/utils/logger';
+
+import { getOpenaiService } from '../services/openai';
+
+import { certifyAiService } from './middlewares/certify-ai-service';
+
+const logger = loggerFactory('growi:routes:apiv3:openai:get-threads');
+
+type GetThreadsFactory = (crowi: Crowi) => RequestHandler[];
+
+type ReqParams = {
+  aiAssistantId: string,
+}
+
+type Req = Request<ReqParams, Response, undefined> & {
+  user: IUserHasId,
+}
+
+export const getThreadsFactory: GetThreadsFactory = (crowi) => {
+  const loginRequiredStrictly = require('~/server/middlewares/login-required')(crowi);
+
+  const validator: ValidationChain[] = [
+    param('aiAssistantId').isMongoId().withMessage('aiAssistantId must be string'),
+  ];
+
+  return [
+    accessTokenParser, loginRequiredStrictly, certifyAiService, validator, apiV3FormValidator,
+    async(req: Req, res: ApiV3Response) => {
+      const openaiService = getOpenaiService();
+      if (openaiService == null) {
+        return res.apiv3Err(new ErrorV3('GROWI AI is not enabled'), 501);
+      }
+
+      try {
+        const { aiAssistantId } = req.params;
+
+        const isAiAssistantUsable = openaiService.isAiAssistantUsable(aiAssistantId, req.user);
+        if (!isAiAssistantUsable) {
+          return res.apiv3Err(new ErrorV3('The specified AI assistant is not usable'), 400);
+        }
+
+        const vectorStoreRelation = await openaiService.getVectorStoreRelation(aiAssistantId);
+        const threads = await openaiService.getThreads(vectorStoreRelation._id);
+
+        return res.apiv3({ threads });
+      }
+      catch (err) {
+        logger.error(err);
+        return res.apiv3Err(new ErrorV3('Failed to get threads'));
+      }
+    },
+  ];
+};

+ 4 - 0
apps/app/src/features/openai/server/routes/index.ts

@@ -27,6 +27,10 @@ export const factory = (crowi: Crowi): express.Router => {
       router.post('/thread', createThreadHandlersFactory(crowi));
     });
 
+    import('./get-threads').then(({ getThreadsFactory }) => {
+      router.get('/threads/:aiAssistantId', getThreadsFactory(crowi));
+    });
+
     import('./message').then(({ postMessageHandlersFactory }) => {
       router.post('/message', postMessageHandlersFactory(crowi));
     });

+ 7 - 1
apps/app/src/features/openai/server/services/openai.ts

@@ -13,7 +13,7 @@ import mongoose, { type HydratedDocument, type Types } from 'mongoose';
 import { type OpenAI, toFile } from 'openai';
 
 import ExternalUserGroupRelation from '~/features/external-user-group/server/models/external-user-group-relation';
-import ThreadRelationModel from '~/features/openai/server/models/thread-relation';
+import ThreadRelationModel, { type ThreadRelationDocument } from '~/features/openai/server/models/thread-relation';
 import VectorStoreModel, { type VectorStoreDocument } from '~/features/openai/server/models/vector-store';
 import VectorStoreFileRelationModel, {
   type VectorStoreFileRelation,
@@ -62,6 +62,7 @@ const convertPathPatternsToRegExp = (pagePathPatterns: string[]): Array<string |
 
 export interface IOpenaiService {
   getOrCreateThread(userId: string, vectorStoreRelation: VectorStoreDocument, threadId?: string): Promise<OpenAI.Beta.Threads.Thread | undefined>;
+  getThreads(vectorStoreRelationId: string): Promise<ThreadRelationDocument[]>
   // getOrCreateVectorStoreForPublicScope(): Promise<VectorStoreDocument>;
   deleteExpiredThreads(limit: number, apiCallInterval: number): Promise<void>; // for CronJob
   deleteObsolatedVectorStoreRelations(): Promise<void> // for CronJob
@@ -117,6 +118,11 @@ class OpenaiService implements IOpenaiService {
     }
   }
 
+  async getThreads(vectorStoreRelationId: string): Promise<ThreadRelationDocument[]> {
+    const threadRelations = await ThreadRelationModel.find({ vectorStore: vectorStoreRelationId });
+    return threadRelations;
+  }
+
   public async deleteExpiredThreads(limit: number, apiCallInterval: number): Promise<void> {
     const expiredThreadRelations = await ThreadRelationModel.getExpiredThreadRelations(limit);
     if (expiredThreadRelations == null) {