Преглед изворни кода

Merge branch 'feat/155690-implement-openai-thread-model' into feat/155763-automatic-thread-deletion

Shun Miyazawa пре 1 година
родитељ
комит
510b84b449

+ 18 - 60
apps/app/src/features/openai/server/models/thread-relation.ts

@@ -11,84 +11,42 @@ const generateExpirationDate = (): Date => {
   return expirationDate;
 };
 
-
-/*
-*  Thread Model
-*/
-interface Thread {
+interface ThreadRelation {
+  userId: mongoose.Types.ObjectId;
   threadId: string;
   expiredAt: Date;
 }
 
-interface ThreadDocument extends Thread, Document {
-  updateExpiration(): Promise<void>;
-}
-
-type ThreadModel = Model<ThreadDocument>
-
-const threadSchema = new Schema<Thread, ThreadDocument, ThreadModel>({
-  threadId: {
-    type: String,
-    required: true,
-  },
-  expiredAt: {
-    type: Date,
-    required: true,
-  },
-});
-
-threadSchema.methods.updateExpiration = async function(): Promise<void> {
-  this.expiredAt = generateExpirationDate();
-  this.parent().save();
-};
-
-
-/*
-*  Thread Relation Model
-*/
-interface ThreadRelation {
-  userId: mongoose.Types.ObjectId;
-  threads: ThreadDocument[];
+interface ThreadRelationDocument extends ThreadRelation, Document {
+  updateThreadExpiration(): Promise<void>;
 }
-interface ThreadRelationDocument extends ThreadRelation, Document {}
 
 interface ThreadRelationModel extends Model<ThreadRelationDocument> {
   upsertThreadRelation(userId: string, threadId: string): Promise<void>;
   getThreadRelation(userId: string, threadId: string): Promise<ThreadRelationDocument | null>
 }
 
-const threadRelationSchema = new Schema<ThreadRelationDocument, ThreadRelationModel>({
+const schema = new Schema<ThreadRelationDocument, ThreadRelationModel>({
   userId: {
     type: Schema.Types.ObjectId,
     ref: 'User',
     required: true,
+  },
+  threadId: {
+    type: String,
+    required: true,
     unique: true,
   },
-  threads: [threadSchema],
+  expiredAt: {
+    type: Date,
+    default: generateExpirationDate,
+    required: true,
+  },
 });
 
-
-threadRelationSchema.statics.upsertThreadRelation = async function(userId: string, threadId: string): Promise<void> {
-  const expirationDate = generateExpirationDate();
-
-  await this.updateOne(
-    { userId },
-    {
-      $push: {
-        threads: {
-          threadId,
-          expiredAt: expirationDate,
-        },
-      },
-    },
-    { upsert: true },
-  );
-};
-
-
-threadRelationSchema.statics.getThreadRelation = async function(userId: string, threadId: string): Promise<ThreadRelationDocument | null> {
-  const result = await this.findOne({ userId, 'threads.threadId': threadId });
-  return result;
+schema.methods.updateThreadExpiration = async function(): Promise<void> {
+  this.expiredAt = generateExpirationDate();
+  await this.save();
 };
 
-export default getOrCreateModel<ThreadRelationDocument, ThreadRelationModel>('ThreadRelation', threadRelationSchema);
+export default getOrCreateModel<ThreadRelationDocument, ThreadRelationModel>('ThreadRelation', schema);

+ 3 - 2
apps/app/src/features/openai/server/routes/thread.ts

@@ -2,6 +2,7 @@ import type { IUserHasId } from '@growi/core/dist/interfaces';
 import type { Request, RequestHandler } from 'express';
 import type { ValidationChain } from 'express-validator';
 import { body } from 'express-validator';
+import { filterXSS } from 'xss';
 
 import type Crowi from '~/server/crowi';
 import { apiV3FormValidator } from '~/server/middlewares/apiv3-form-validator';
@@ -30,9 +31,9 @@ export const createThreadHandlersFactory: CreateThreadFactory = (crowi) => {
     accessTokenParser, loginRequiredStrictly, certifyAiService, validator, apiV3FormValidator,
     async(req: CreateThreadReq, res: ApiV3Response) => {
       try {
-        const user = req.user;
+        const filterdThreadId = req.body.threadId != null ? filterXSS(req.body.threadId) : undefined;
         const openaiService = getOpenaiService();
-        const thread = await openaiService?.getOrCreateThread(user._id, req.body.threadId);
+        const thread = await openaiService?.getOrCreateThread(req.user._id, filterdThreadId);
         return res.apiv3({ thread });
       }
       catch (err) {

+ 5 - 6
apps/app/src/features/openai/server/services/openai.ts

@@ -48,21 +48,20 @@ class OpenaiService implements IOpenaiService {
     if (threadId == null) {
       const vectorStore = await this.getOrCreateVectorStoreForPublicScope();
       const thread = await this.client.createThread(vectorStore.vectorStoreId);
-      await ThreadRelationModel.upsertThreadRelation(userId, thread.id);
+      await ThreadRelationModel.create({ userId, threadId: thread.id });
       return thread;
     }
 
-    const threadRelation = await ThreadRelationModel.getThreadRelation(userId, threadId);
-    const threadDocument = threadRelation?.threads.find(thread => thread.threadId === threadId);
-    if (threadDocument == null) {
+    const threadRelation = await ThreadRelationModel.findOne({ threadId });
+    if (threadRelation == null) {
       return;
     }
 
     // Check if a thread entity exists
-    const thread = await this.client.retrieveThread(threadDocument.threadId);
+    const thread = await this.client.retrieveThread(threadRelation.threadId);
 
     // Update expiration date if thread entity exists
-    await threadDocument.updateExpiration();
+    await threadRelation.updateThreadExpiration();
 
     return thread;
   }