Browse Source

Delegate creation or retrieval of threads to service methods

Shun Miyazawa 1 year ago
parent
commit
9683877799

+ 2 - 18
apps/app/src/features/openai/server/routes/thread.ts

@@ -7,7 +7,6 @@ import { apiV3FormValidator } from '~/server/middlewares/apiv3-form-validator';
 import type { ApiV3Response } from '~/server/routes/apiv3/interfaces/apiv3-response';
 import loggerFactory from '~/utils/logger';
 
-import { openaiClient } from '../services';
 import { getOpenaiService } from '../services/openai';
 
 import { certifyAiService } from './middlewares/certify-ai-service';
@@ -32,24 +31,9 @@ export const createThreadHandlersFactory: CreateThreadFactory = (crowi) => {
   return [
     accessTokenParser, loginRequiredStrictly, certifyAiService, validator, apiV3FormValidator,
     async(req: CreateThreadReq, res: ApiV3Response) => {
-      const openaiService = getOpenaiService();
-      if (openaiService == null) {
-        return res.apiv3Err('OpenaiService is not available', 503);
-      }
-
       try {
-        const vectorStore = await openaiService.getOrCreateVectorStoreForPublicScope();
-        const threadId = req.body.threadId;
-        const thread = threadId == null
-          ? await openaiClient.beta.threads.create({
-            tool_resources: {
-              file_search: {
-                vector_store_ids: [vectorStore.vectorStoreId],
-              },
-            },
-          })
-          : await openaiClient.beta.threads.retrieve(threadId);
-
+        const openaiService = getOpenaiService();
+        const thread = await openaiService?.getOrCreateThread(req.body.threadId);
         return res.apiv3({ thread });
       }
       catch (err) {

+ 14 - 0
apps/app/src/features/openai/server/services/client-delegator/azure-openai-client-delegator.ts

@@ -22,6 +22,20 @@ export class AzureOpenaiClientDelegator implements IOpenaiClientDelegator {
     // TODO: initialize openaiVectorStoreId property
   }
 
+  async createThread(vectorStoreId: string): Promise<OpenAI.Beta.Threads.Thread> {
+    return this.client.beta.threads.create({
+      tool_resources: {
+        file_search: {
+          vector_store_ids: [vectorStoreId],
+        },
+      },
+    });
+  }
+
+  async retrieveThread(threadId: string): Promise<OpenAI.Beta.Threads.Thread> {
+    return this.client.beta.threads.retrieve(threadId);
+  }
+
   async createVectorStore(scopeType:VectorStoreScopeType): Promise<OpenAI.Beta.VectorStores.VectorStore> {
     return this.client.beta.vectorStores.create({ name: `growi-vector-store-{${scopeType}` });
   }

+ 2 - 0
apps/app/src/features/openai/server/services/client-delegator/interfaces.ts

@@ -4,6 +4,8 @@ import type { Uploadable } from 'openai/uploads';
 import type { VectorStoreScopeType } from '~/features/openai/server/models/vector-store';
 
 export interface IOpenaiClientDelegator {
+  createThread(vectorStoreId: string): Promise<OpenAI.Beta.Threads.Thread>
+  retrieveThread(threadId: string): Promise<OpenAI.Beta.Threads.Thread>
   retrieveVectorStore(vectorStoreId: string): Promise<OpenAI.Beta.VectorStores.VectorStore>
   createVectorStore(scopeType:VectorStoreScopeType): Promise<OpenAI.Beta.VectorStores.VectorStore>
   uploadFile(file: Uploadable): Promise<OpenAI.Files.FileObject>

+ 14 - 0
apps/app/src/features/openai/server/services/client-delegator/openai-client-delegator.ts

@@ -24,6 +24,20 @@ export class OpenaiClientDelegator implements IOpenaiClientDelegator {
     this.client = new OpenAI({ apiKey });
   }
 
+  async createThread(vectorStoreId: string): Promise<OpenAI.Beta.Threads.Thread> {
+    return this.client.beta.threads.create({
+      tool_resources: {
+        file_search: {
+          vector_store_ids: [vectorStoreId],
+        },
+      },
+    });
+  }
+
+  async retrieveThread(threadId: string): Promise<OpenAI.Beta.Threads.Thread> {
+    return this.client.beta.threads.retrieve(threadId);
+  }
+
   async createVectorStore(scopeType:VectorStoreScopeType): Promise<OpenAI.Beta.VectorStores.VectorStore> {
     return this.client.beta.vectorStores.create({ name: `growi-vector-store-${scopeType}` });
   }

+ 10 - 0
apps/app/src/features/openai/server/services/openai.ts

@@ -29,6 +29,7 @@ const logger = loggerFactory('growi:service:openai');
 let isVectorStoreForPublicScopeExist = false;
 
 export interface IOpenaiService {
+  getOrCreateThread(threadId?: string): Promise<OpenAI.Beta.Threads.Thread>;
   getOrCreateVectorStoreForPublicScope(): Promise<VectorStoreDocument>;
   createVectorStoreFile(pages: PageDocument[]): Promise<void>;
   deleteVectorStoreFile(pageId: Types.ObjectId): Promise<void>;
@@ -42,6 +43,15 @@ class OpenaiService implements IOpenaiService {
     return getClient({ openaiServiceType });
   }
 
+  public async getOrCreateThread(threadId?: string): Promise<OpenAI.Beta.Threads.Thread> {
+    const vectorStore = await this.getOrCreateVectorStoreForPublicScope();
+    const thread = threadId == null
+      ? await this.client.createThread(vectorStore.vectorStoreId)
+      : await this.client.retrieveThread(threadId);
+
+    return thread;
+  }
+
   public async getOrCreateVectorStoreForPublicScope(): Promise<VectorStoreDocument> {
     const vectorStoreDocument = await VectorStoreModel.findOne({ scorpeType: VectorStoreScopeType.PUBLIC });