Просмотр исходного кода

Implement thread relation model

Shun Miyazawa 1 год назад
Родитель
Сommit
9ac4c09230

+ 77 - 0
apps/app/src/features/openai/server/models/thread-relation.ts

@@ -0,0 +1,77 @@
+import type mongoose from 'mongoose';
+import { type Model, type Document, Schema } from 'mongoose';
+
+import { getOrCreateModel } from '~/server/util/mongoose-utils';
+
+const DAYS_UNTIL_EXPIRATION = 30;
+
+interface Thread {
+  threadId: string;
+  expiredAt: Date;
+}
+interface ThreadRelation {
+  userId: mongoose.Types.ObjectId;
+  threads: Thread[];
+}
+
+interface ThreadDocument extends ThreadRelation, Document {}
+
+interface ThreadRelationModel extends Model<ThreadDocument> {
+  upsertThreadRelation(userId: string, threadId: string): Promise<void>;
+  getThread(userId: string, threadId: string): Promise<Thread | undefined>;
+}
+
+const schema = new Schema<ThreadDocument, ThreadRelationModel>({
+  userId: {
+    type: Schema.Types.ObjectId,
+    ref: 'User',
+    required: true,
+    unique: true,
+  },
+  threads: [{
+    threadId: {
+      type: String,
+      required: true,
+    },
+    expiredAt: {
+      type: Date,
+      required: true,
+    },
+  },
+  ],
+});
+
+
+schema.statics.upsertThreadRelation = async function(userId: string, threadId: string) {
+  const currentDate = new Date();
+  const expirationDate = new Date(currentDate.setDate(currentDate.getDate() + DAYS_UNTIL_EXPIRATION));
+
+  await this.updateOne(
+    { userId },
+    {
+      $push: {
+        threads: {
+          threadId,
+          expiredAt: expirationDate,
+        },
+      },
+    },
+    { upsert: true },
+  );
+};
+
+schema.statics.getThread = async function(userId: string, threadId: string): Promise<Thread | undefined> {
+  const result = await this.findOne(
+    { userId, 'threads.threadId': threadId },
+    { threads: { $elemMatch: { threadId } } },
+  );
+
+  if (result != null && result.threads.length > 0) {
+    return {
+      threadId: result.threads[0].threadId,
+      expiredAt: result.threads[0].expiredAt,
+    };
+  }
+};
+
+export default getOrCreateModel<ThreadDocument, ThreadRelationModel>('ThreadRelation', schema);

+ 4 - 4
apps/app/src/features/openai/server/routes/thread.ts

@@ -1,3 +1,4 @@
+import type { IUserHasId } from '@growi/core/dist/interfaces';
 import type { Request, RequestHandler } from 'express';
 import type { ValidationChain } from 'express-validator';
 import { body } from 'express-validator';
@@ -13,9 +14,7 @@ import { certifyAiService } from './middlewares/certify-ai-service';
 
 const logger = loggerFactory('growi:routes:apiv3:openai:thread');
 
-type CreateThreadReq = Request<undefined, ApiV3Response, {
-  threadId?: string,
-}>
+type CreateThreadReq = Request<undefined, ApiV3Response, { threadId?: string }> & { user: IUserHasId };
 
 type CreateThreadFactory = (crowi: Crowi) => RequestHandler[];
 
@@ -31,8 +30,9 @@ export const createThreadHandlersFactory: CreateThreadFactory = (crowi) => {
     accessTokenParser, loginRequiredStrictly, certifyAiService, validator, apiV3FormValidator,
     async(req: CreateThreadReq, res: ApiV3Response) => {
       try {
+        const user = req.user;
         const openaiService = getOpenaiService();
-        const thread = await openaiService?.getOrCreateThread(req.body.threadId);
+        const thread = await openaiService?.getOrCreateThread(user._id, req.body.threadId);
         return res.apiv3({ thread });
       }
       catch (err) {

+ 15 - 7
apps/app/src/features/openai/server/services/openai.ts

@@ -7,6 +7,7 @@ import mongoose from 'mongoose';
 import type OpenAI from 'openai';
 import { toFile } from 'openai';
 
+import ThreadRelationModel from '~/features/openai/server/models/thread-relation';
 import VectorStoreModel, { VectorStoreScopeType, type VectorStoreDocument } from '~/features/openai/server/models/vector-store';
 import VectorStoreFileRelationModel, {
   type VectorStoreFileRelation,
@@ -29,7 +30,7 @@ const logger = loggerFactory('growi:service:openai');
 let isVectorStoreForPublicScopeExist = false;
 
 export interface IOpenaiService {
-  getOrCreateThread(threadId?: string): Promise<OpenAI.Beta.Threads.Thread>;
+  getOrCreateThread(userId: string, threadId?: string): Promise<OpenAI.Beta.Threads.Thread | undefined>;
   getOrCreateVectorStoreForPublicScope(): Promise<VectorStoreDocument>;
   createVectorStoreFile(pages: PageDocument[]): Promise<void>;
   deleteVectorStoreFile(pageId: Types.ObjectId): Promise<void>;
@@ -43,13 +44,20 @@ class OpenaiService implements IOpenaiService {
     return getClient({ openaiServiceType });
   }
 
-  public async getOrCreateThread(threadId?: string): Promise<OpenAI.Beta.Threads.Thread> {
-    const vectorStore = await this.getOrCreateVectorStoreForPublicScope();
-    const thread = threadId == null
-      ? await this.client.createThread(vectorStore.vectorStoreId)
-      : await this.client.retrieveThread(threadId);
+  public async getOrCreateThread(userId: string, threadId?: string): Promise<OpenAI.Beta.Threads.Thread | undefined> {
+    if (threadId == null) {
+      const vectorStore = await this.getOrCreateVectorStoreForPublicScope();
+      const thread = await this.client.createThread(vectorStore.vectorStoreId);
+      await ThreadRelationModel.upsertThreadRelation(userId, thread.id);
+      return thread;
+    }
 
-    return thread;
+    const threadDocument = await ThreadRelationModel.getThread(userId, threadId);
+    if (threadDocument != null) {
+      // Check if a thread entity exists
+      const thread = await this.client.retrieveThread(threadDocument.threadId);
+      return thread;
+    }
   }
 
   public async getOrCreateVectorStoreForPublicScope(): Promise<VectorStoreDocument> {