Просмотр исходного кода

Merge pull request #9664 from weseek/feat/162041-implement-api-to-fetch-thread-messages

feat: Implement api to fetch thread messages
Shun Miyazawa 1 год назад
Родитель
Сommit
2038bb87b9

+ 72 - 0
apps/app/src/features/openai/server/routes/get-messages.ts

@@ -0,0 +1,72 @@
+import { type IUserHasId } from '@growi/core';
+import { ErrorV3 } from '@growi/core/dist/models';
+import type { Request, RequestHandler } from 'express';
+import { type ValidationChain, param } from 'express-validator';
+
+import type Crowi from '~/server/crowi';
+import { accessTokenParser } from '~/server/middlewares/access-token-parser';
+import { apiV3FormValidator } from '~/server/middlewares/apiv3-form-validator';
+import type { ApiV3Response } from '~/server/routes/apiv3/interfaces/apiv3-response';
+import loggerFactory from '~/utils/logger';
+
+import { getOpenaiService } from '../services/openai';
+
+import { certifyAiService } from './middlewares/certify-ai-service';
+
+const logger = loggerFactory('growi:routes:apiv3:openai:get-message');
+
+type GetMessagesFactory = (crowi: Crowi) => RequestHandler[];
+
+type ReqParam = {
+  threadId: string,
+  aiAssistantId: string,
+  before?: string,
+  after?: string,
+  limit?: number,
+}
+
+type Req = Request<ReqParam, Response, undefined> & {
+  user: IUserHasId,
+}
+
+export const getMessagesFactory: GetMessagesFactory = (crowi) => {
+  const loginRequiredStrictly = require('~/server/middlewares/login-required')(crowi);
+
+  const validator: ValidationChain[] = [
+    param('threadId').isString().withMessage('threadId must be string'),
+    param('aiAssistantId').isMongoId().withMessage('aiAssistantId must be string'),
+    param('limit').optional().isInt().withMessage('limit must be integer'),
+    param('before').optional().isString().withMessage('before must be string'),
+    param('after').optional().isString().withMessage('after must be string'),
+  ];
+
+  return [
+    accessTokenParser, loginRequiredStrictly, certifyAiService, validator, apiV3FormValidator,
+    async(req: Req, res: ApiV3Response) => {
+      const openaiService = getOpenaiService();
+      if (openaiService == null) {
+        return res.apiv3Err(new ErrorV3('GROWI AI is not enabled'), 501);
+      }
+
+      try {
+        const {
+          threadId, aiAssistantId, limit, before, after,
+        } = req.params;
+
+        const isAiAssistantUsable = openaiService.isAiAssistantUsable(aiAssistantId, req.user);
+        if (!isAiAssistantUsable) {
+          return res.apiv3Err(new ErrorV3('The specified AI assistant is not usable'), 400);
+        }
+
+        const options = { limit, before, after };
+        const messages = await openaiService.getMessageData(threadId, req.user.lang, options);
+
+        return res.apiv3({ messages });
+      }
+      catch (err) {
+        logger.error(err);
+        return res.apiv3Err(new ErrorV3('Failed to get messages'));
+      }
+    },
+  ];
+};

+ 4 - 0
apps/app/src/features/openai/server/routes/index.ts

@@ -35,6 +35,10 @@ export const factory = (crowi: Crowi): express.Router => {
       router.post('/message', postMessageHandlersFactory(crowi));
     });
 
+    import('./get-messages').then(({ getMessagesFactory }) => {
+      router.get('/messages/:aiAssistantId/:threadId', getMessagesFactory(crowi));
+    });
+
     import('./ai-assistant').then(({ createAiAssistantFactory }) => {
       router.post('/ai-assistant', createAiAssistantFactory(crowi));
     });

+ 8 - 0
apps/app/src/features/openai/server/services/client-delegator/azure-openai-client-delegator.ts

@@ -38,6 +38,14 @@ export class AzureOpenaiClientDelegator implements IOpenaiClientDelegator {
     return this.client.beta.threads.del(threadId);
   }
 
+  async getMessages(threadId: string, options?: { before: string, after: string, limit: number }): Promise<OpenAI.Beta.Threads.Messages.MessagesPage> {
+    return this.client.beta.threads.messages.list(threadId, {
+      limit: options?.limit,
+      before: options?.before,
+      after: options?.after,
+    });
+  }
+
   async createVectorStore(name: string): Promise<OpenAI.Beta.VectorStores.VectorStore> {
     return this.client.beta.vectorStores.create({ name: `growi-vector-store-for-${name}` });
   }

+ 1 - 0
apps/app/src/features/openai/server/services/client-delegator/interfaces.ts

@@ -5,6 +5,7 @@ export interface IOpenaiClientDelegator {
   createThread(vectorStoreId: string): Promise<OpenAI.Beta.Threads.Thread>
   retrieveThread(threadId: string): Promise<OpenAI.Beta.Threads.Thread>
   deleteThread(threadId: string): Promise<OpenAI.Beta.Threads.ThreadDeleted>
+  getMessages(threadId: string, options?: { limit: number, before: string, after: string }): Promise<OpenAI.Beta.Threads.Messages.MessagesPage>
   retrieveVectorStore(vectorStoreId: string): Promise<OpenAI.Beta.VectorStores.VectorStore>
   createVectorStore(name: string): Promise<OpenAI.Beta.VectorStores.VectorStore>
   deleteVectorStore(vectorStoreId: string): Promise<OpenAI.Beta.VectorStores.VectorStoreDeleted>

+ 8 - 0
apps/app/src/features/openai/server/services/client-delegator/openai-client-delegator.ts

@@ -41,6 +41,14 @@ export class OpenaiClientDelegator implements IOpenaiClientDelegator {
     return this.client.beta.threads.del(threadId);
   }
 
+  async getMessages(threadId: string, options?: { before?: string, after?: string, limit?: number }): Promise<OpenAI.Beta.Threads.Messages.MessagesPage> {
+    return this.client.beta.threads.messages.list(threadId, {
+      limit: options?.limit,
+      before: options?.before,
+      after: options?.after,
+    });
+  }
+
   async createVectorStore(name: string): Promise<OpenAI.Beta.VectorStores.VectorStore> {
     return this.client.beta.vectorStores.create({ name: `growi-vector-store-for-${name}` });
   }

+ 21 - 0
apps/app/src/features/openai/server/services/openai.ts

@@ -2,6 +2,7 @@ import assert from 'node:assert';
 import { Readable, Transform } from 'stream';
 import { pipeline } from 'stream/promises';
 
+import type { Lang } from '@growi/core';
 import {
   PageGrant, getIdForRef, getIdStringForRef, isPopulated, type IUserHasId,
 } from '@growi/core';
@@ -35,6 +36,7 @@ import { convertMarkdownToHtml } from '../utils/convert-markdown-to-html';
 import { getClient } from './client-delegator';
 // import { splitMarkdownIntoChunks } from './markdown-splitter/markdown-token-splitter';
 import { openaiApiErrorHandler } from './openai-api-error-handler';
+import { replaceAnnotationWithPageLink } from './replace-annotation-with-page-link';
 
 const { isDeepEquals } = deepEquals;
 
@@ -66,6 +68,9 @@ export interface IOpenaiService {
   // getOrCreateVectorStoreForPublicScope(): Promise<VectorStoreDocument>;
   deleteExpiredThreads(limit: number, apiCallInterval: number): Promise<void>; // for CronJob
   deleteObsolatedVectorStoreRelations(): Promise<void> // for CronJob
+  getMessageData(
+    threadId: string, lang?: Lang, options?: { before?: string, after?: string, limit?: number }
+  ): Promise<OpenAI.Beta.Threads.Messages.MessagesPage>;
   getVectorStoreRelation(aiAssistantId: string): Promise<VectorStoreDocument>
   getVectorStoreRelationsByPageIds(pageId: Types.ObjectId[]): Promise<VectorStoreDocument[]>;
   createVectorStoreFile(vectorStoreRelation: VectorStoreDocument, pages: PageDocument[]): Promise<void>;
@@ -150,6 +155,22 @@ class OpenaiService implements IOpenaiService {
     await ThreadRelationModel.deleteMany({ threadId: { $in: deletedThreadIds } });
   }
 
+  async getMessageData(
+      threadId: string, lang?: Lang, options?: { limit: number, before: string, after: string },
+  ): Promise<OpenAI.Beta.Threads.Messages.MessagesPage> {
+    const messages = await this.client.getMessages(threadId, options);
+
+    for await (const message of messages.data) {
+      for await (const content of message.content) {
+        if (content.type === 'text') {
+          await replaceAnnotationWithPageLink(content, lang);
+        }
+      }
+    }
+
+    return messages;
+  }
+
   // TODO: https://redmine.weseek.co.jp/issues/160332
   // public async getOrCreateVectorStoreForPublicScope(): Promise<VectorStoreDocument> {
   //   const vectorStoreDocument: VectorStoreDocument | null = await VectorStoreModel.findOne({ scopeType: VectorStoreScopeType.PUBLIC, isDeleted: false });

+ 5 - 5
apps/app/src/features/openai/server/services/replace-annotation-with-page-link.ts

@@ -1,14 +1,14 @@
 // See: https://platform.openai.com/docs/assistants/tools/file-search#step-5-create-a-run-and-check-the-output
 
 import type { IPageHasId, Lang } from '@growi/core/dist/interfaces';
-import type { MessageContentDelta } from 'openai/resources/beta/threads/messages.mjs';
+import type { MessageContentDelta, MessageContent } from 'openai/resources/beta/threads/messages.mjs';
 
 import VectorStoreFileRelationModel from '~/features/openai/server/models/vector-store-file-relation';
 import { getTranslation } from '~/server/service/i18next';
 
-export const replaceAnnotationWithPageLink = async(messageContentDelta: MessageContentDelta, lang?: Lang): Promise<void> => {
-  if (messageContentDelta?.type === 'text' && messageContentDelta?.text?.annotations != null) {
-    const annotations = messageContentDelta?.text?.annotations;
+export const replaceAnnotationWithPageLink = async(messageContent: MessageContentDelta | MessageContent, lang?: Lang): Promise<void> => {
+  if (messageContent?.type === 'text' && messageContent?.text?.annotations != null) {
+    const annotations = messageContent?.text?.annotations;
     for await (const annotation of annotations) {
       if (annotation.type === 'file_citation' && annotation.text != null) {
 
@@ -18,7 +18,7 @@ export const replaceAnnotationWithPageLink = async(messageContentDelta: MessageC
 
         if (vectorStoreFileRelation != null) {
           const { t } = await getTranslation({ lang });
-          messageContentDelta.text.value = messageContentDelta.text.value?.replace(
+          messageContent.text.value = messageContent.text.value?.replace(
             annotation.text,
             ` [${t('source')}: [${vectorStoreFileRelation.page.path}](/${vectorStoreFileRelation.page._id})]`,
           );