Browse Source

Refactor OpenaiService

Shun Miyazawa 1 year ago
parent
commit
954d5c2c7f
1 changed file with 39 additions and 22 deletions
  1. 39 22
      apps/app/src/features/openai/server/services/openai.ts

+ 39 - 22
apps/app/src/features/openai/server/services/openai.ts

@@ -4,8 +4,7 @@ import { Readable, Transform } from 'stream';
 import { PageGrant, isPopulated } from '@growi/core';
 import type { HydratedDocument, Types } from 'mongoose';
 import mongoose from 'mongoose';
-import type OpenAI from 'openai';
-import { toFile } from 'openai';
+import OpenAI, { toFile } from 'openai';
 
 import ThreadRelationModel from '~/features/openai/server/models/thread-relation';
 import VectorStoreModel, { VectorStoreScopeType, type VectorStoreDocument } from '~/features/openai/server/models/vector-store';
@@ -45,7 +44,7 @@ class OpenaiService implements IOpenaiService {
     return getClient({ openaiServiceType });
   }
 
-  public async getOrCreateThread(userId: string, vectorStoreId?: string, threadId?: string): Promise<OpenAI.Beta.Threads.Thread | undefined> {
+  public async getOrCreateThread(userId: string, vectorStoreId?: string, threadId?: string): Promise<OpenAI.Beta.Threads.Thread> {
     if (vectorStoreId != null && threadId == null) {
       try {
         const thread = await this.client.createThread(vectorStoreId);
@@ -63,12 +62,23 @@ class OpenaiService implements IOpenaiService {
     }
 
     // Check if a thread entity exists
-    const thread = await this.client.retrieveThread(threadRelation.threadId);
+    // If the thread entity does not exist, the thread-relation document is deleted
+    try {
+      const thread = await this.client.retrieveThread(threadRelation.threadId);
 
-    // Update expiration date if thread entity exists
-    await threadRelation.updateThreadExpiration();
+      // Update expiration date if thread entity exists
+      await threadRelation.updateThreadExpiration();
 
-    return thread;
+      return thread;
+    }
+    catch (err) {
+      if (err instanceof OpenAI.APIError) {
+        if (err.status === 404) {
+          await threadRelation.remove();
+        }
+      }
+      throw new Error(err);
+    }
   }
 
   public async deleteExpiredThreads(limit: number): Promise<void> {
@@ -78,9 +88,8 @@ class OpenaiService implements IOpenaiService {
     }
 
     const deletedThreadIds: string[] = [];
-    for (const expiredThreadRelation of expiredThreadRelations) {
+    for await (const expiredThreadRelation of expiredThreadRelations) {
       try {
-        // eslint-disable-next-line no-await-in-loop
         const deleteThreadResponse = await this.client.deleteThread(expiredThreadRelation.threadId);
         logger.debug('Delete thread', deleteThreadResponse);
         deletedThreadIds.push(expiredThreadRelation.threadId);
@@ -101,11 +110,20 @@ class OpenaiService implements IOpenaiService {
     }
 
     if (vectorStoreDocument != null && !isVectorStoreForPublicScopeExist) {
-      const vectorStore = await this.client.retrieveVectorStore(vectorStoreDocument.vectorStoreId);
-      if (vectorStore != null) {
+      try {
+        await this.client.retrieveVectorStore(vectorStoreDocument.vectorStoreId);
         isVectorStoreForPublicScopeExist = true;
         return vectorStoreDocument;
       }
+      catch (err) {
+        if (err instanceof OpenAI.APIError) {
+          if (err.status === 404) {
+            vectorStoreDocument.remove();
+          }
+        }
+        logger.error(err);
+        throw new Error(err);
+      }
     }
 
     const newVectorStore = await this.client.createVectorStore(VectorStoreScopeType.PUBLIC);
@@ -125,7 +143,7 @@ class OpenaiService implements IOpenaiService {
     return uploadedFile;
   }
 
-  async createVectorStoreFile(pages: Array<PageDocument>): Promise<void> {
+  async createVectorStoreFile(pages: Array<HydratedDocument<PageDocument>>): Promise<void> {
     const vectorStoreFileRelationsMap: Map<string, VectorStoreFileRelation> = new Map();
     const processUploadFile = async(page: PageDocument) => {
       if (page._id != null && page.grant === PageGrant.GRANT_PUBLIC && page.revision != null) {
@@ -163,22 +181,22 @@ class OpenaiService implements IOpenaiService {
     }
 
     try {
+      // Save vector store file relation
+      await VectorStoreFileRelationModel.upsertVectorStoreFileRelations(vectorStoreFileRelations);
+
       // Create vector store file
       const vectorStore = await this.getOrCreateVectorStoreForPublicScope();
       const createVectorStoreFileBatchResponse = await this.client.createVectorStoreFileBatch(vectorStore.vectorStoreId, uploadedFileIds);
       logger.debug('Create vector store file', createVectorStoreFileBatchResponse);
-
-      // Save vector store file relation
-      await VectorStoreFileRelationModel.upsertVectorStoreFileRelations(vectorStoreFileRelations);
     }
     catch (err) {
       logger.error(err);
 
       // Delete all uploaded files if createVectorStoreFileBatch fails
-      uploadedFileIds.forEach(async(fileId) => {
-        const deleteFileResponse = await this.client.deleteFile(fileId);
-        logger.debug('Delete vector store file (Due to createVectorStoreFileBatch failure)', deleteFileResponse);
-      });
+      const pageIds = pages.map(page => page._id);
+      for await (const pageId of pageIds) {
+        await this.deleteVectorStoreFile(pageId);
+      }
     }
 
   }
@@ -191,9 +209,8 @@ class OpenaiService implements IOpenaiService {
     }
 
     const deletedFileIds: string[] = [];
-    for (const fileId of vectorStoreFileRelation.fileIds) {
+    for await (const fileId of vectorStoreFileRelation.fileIds) {
       try {
-        // eslint-disable-next-line no-await-in-loop
         const deleteFileResponse = await this.client.deleteFile(fileId);
         logger.debug('Delete vector store file', deleteFileResponse);
         deletedFileIds.push(fileId);
@@ -225,7 +242,7 @@ class OpenaiService implements IOpenaiService {
     const createVectorStoreFile = this.createVectorStoreFile.bind(this);
     const createVectorStoreFileStream = new Transform({
       objectMode: true,
-      async transform(chunk: PageDocument[], encoding, callback) {
+      async transform(chunk: HydratedDocument<PageDocument>[], encoding, callback) {
         await createVectorStoreFile(chunk);
         this.push(chunk);
         callback();