Explorar o código

Merge pull request #10089 from weseek/feat/pre-response-message

feat(ai): Send pre-message before main chat stream
Shun Miyazawa hai 9 meses
pai
achega
e2c19c508b

+ 30 - 5
apps/app/src/features/openai/client/components/AiAssistant/AiAssistantSidebar/AiAssistantSidebar.tsx

@@ -241,7 +241,10 @@ const AiAssistantSidebarSubstance: React.FC<AiAssistantSidebarSubstanceProps> =
 
         const chunk = decoder.decode(value);
 
-        const textValues: string[] = [];
+        let isPreMessageGenerated = false;
+        let isMainMessageGenerationStarted = false;
+        const preMessages: string[] = [];
+        const mainMessages: string[] = [];
         const lines = chunk.split('\n\n');
         lines.forEach((line) => {
           const trimmedLine = line.trim();
@@ -249,14 +252,37 @@ const AiAssistantSidebarSubstance: React.FC<AiAssistantSidebarSubstanceProps> =
             const data = JSON.parse(line.replace('data: ', ''));
 
             processMessageForKnowledgeAssistant(data, {
+              onPreMessage: (data) => {
+                // The main message has already started streaming; discard any buffered pre-message text
+                if (isMainMessageGenerationStarted) {
+                  preMessages.length = 0;
+                  return;
+                }
+                if (data.finished) {
+                  isPreMessageGenerated = true;
+                  return;
+                }
+                if (data.text == null) {
+                  return;
+                }
+                preMessages.push(data.text);
+              },
               onMessage: (data) => {
-                textValues.push(data.content[0].text.value);
+                if (!isMainMessageGenerationStarted) {
+                  isMainMessageGenerationStarted = true;
+                }
+
+                // If the main message arrives before the pre-message finished, drop the partial pre-message
+                if (!isPreMessageGenerated) {
+                  preMessages.length = 0;
+                }
+                mainMessages.push(data.content[0].text.value);
               },
             });
 
             processMessageForEditorAssistant(data, {
               onMessage: (data) => {
-                textValues.push(data.appendedMessage);
+                mainMessages.push(data.appendedMessage);
               },
               onDetectedDiff: (data) => {
                 logger.debug('sse diff', { data });
@@ -277,13 +303,12 @@ const AiAssistantSidebarSubstance: React.FC<AiAssistantSidebarSubstanceProps> =
           }
         });
 
-
         // append text values to the assistant message
         setGeneratingAnswerMessage((prevMessage) => {
           if (prevMessage == null) return;
           return {
             ...prevMessage,
-            content: prevMessage.content + textValues.join(''),
+            content: prevMessage.content + preMessages.join('') + mainMessages.join(''),
           };
         });
 

+ 10 - 2
apps/app/src/features/openai/client/services/knowledge-assistant.tsx

@@ -10,7 +10,9 @@ import {
 } from 'reactstrap';
 
 import { apiv3Post } from '~/client/util/apiv3-client';
-import { SseMessageSchema, type SseMessage } from '~/features/openai/interfaces/knowledge-assistant/sse-schemas';
+import {
+  SseMessageSchema, type SseMessage, SsePreMessageSchema, type SsePreMessage,
+} from '~/features/openai/interfaces/knowledge-assistant/sse-schemas';
 import { handleIfSuccessfullyParsed } from '~/features/openai/utils/handle-if-successfully-parsed';
 
 import type { MessageLog, MessageWithCustomMetaData } from '../../interfaces/message';
@@ -31,7 +33,9 @@ interface PostMessage {
 
 interface ProcessMessage {
   (data: unknown, handler: {
-    onMessage: (data: SseMessage) => void}
+    onMessage: (data: SseMessage) => void
+    onPreMessage: (data: SsePreMessage) => void
+  }
   ): void;
 }
 
@@ -121,6 +125,10 @@ export const useKnowledgeAssistant: UseKnowledgeAssistant = () => {
     handleIfSuccessfullyParsed(data, SseMessageSchema, (data: SseMessage) => {
       handler.onMessage(data);
     });
+
+    handleIfSuccessfullyParsed(data, SsePreMessageSchema, (data: SsePreMessage) => {
+      handler.onPreMessage(data);
+    });
   }, []);
 
   // Views

+ 6 - 0
apps/app/src/features/openai/interfaces/knowledge-assistant/sse-schemas.ts

@@ -11,6 +11,12 @@ export const SseMessageSchema = z.object({
   })),
 });
 
+export const SsePreMessageSchema = z.object({
+  text: z.string().nullish().describe('The pre-message that should be appended to the chat window'),
+  finished: z.boolean().describe('Indicates if the pre-message generation is finished'),
+});
+
 
 // Type definitions
 export type SseMessage = z.infer<typeof SseMessageSchema>;
+export type SsePreMessage = z.infer<typeof SsePreMessageSchema>;

+ 21 - 0
apps/app/src/features/openai/server/routes/message/post-message.ts

@@ -5,6 +5,7 @@ import type { ValidationChain } from 'express-validator';
 import { body } from 'express-validator';
 import type { AssistantStream } from 'openai/lib/AssistantStream';
 import type { MessageDelta } from 'openai/resources/beta/threads/messages.mjs';
+import { type ChatCompletionChunk } from 'openai/resources/chat/completions';
 
 import { getOrCreateChatAssistant } from '~/features/openai/server/services/assistant';
 import type Crowi from '~/server/crowi';
@@ -115,11 +116,25 @@ export const postMessageHandlersFactory: PostMessageHandlersFactory = (crowi) =>
         return res.status(500).send(err.message);
       }
 
+      /**
+       * Set up the SSE (Server-Sent Events) response headers
+       */
       res.writeHead(200, {
         'Content-Type': 'text/event-stream;charset=utf-8',
         'Cache-Control': 'no-cache, no-transform',
       });
 
+      const preMessageChunkHandler = (chunk: ChatCompletionChunk) => {
+        const chunkChoice = chunk.choices[0];
+
+        const content = {
+          text: chunkChoice.delta.content,
+          finished: chunkChoice.finish_reason != null,
+        };
+
+        res.write(`data: ${JSON.stringify(content)}\n\n`);
+      };
+
       const messageDeltaHandler = async(delta: MessageDelta) => {
         const content = delta.content?.[0];
 
@@ -135,6 +150,12 @@ export const postMessageHandlersFactory: PostMessageHandlersFactory = (crowi) =>
         res.write(`error: ${JSON.stringify({ code, message })}\n\n`);
       };
 
+      // Intentionally not awaited: the pre-message SSE stream runs concurrently with the main message stream
+      openaiService.generateAndProcessPreMessage(req.body.userMessage, preMessageChunkHandler)
+        .catch((err) => {
+          logger.error(err);
+        });
+
       stream.on('event', (delta) => {
         if (delta.event === 'thread.run.failed') {
           const errorMessage = delta.data.last_error?.message;

+ 4 - 1
apps/app/src/features/openai/server/services/client-delegator/azure-openai-client-delegator.ts

@@ -1,6 +1,7 @@
 import { DefaultAzureCredential, getBearerTokenProvider } from '@azure/identity';
 import type OpenAI from 'openai';
 import { AzureOpenAI } from 'openai';
+import { type Stream } from 'openai/streaming';
 import { type Uploadable } from 'openai/uploads';
 
 import type { MessageListParams } from '../../../interfaces/message';
@@ -94,7 +95,9 @@ export class AzureOpenaiClientDelegator implements IOpenaiClientDelegator {
     return this.client.vectorStores.fileBatches.uploadAndPoll(vectorStoreId, { files });
   }
 
-  async chatCompletion(body: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming): Promise<OpenAI.Chat.Completions.ChatCompletion> {
+  async chatCompletion(
+      body: OpenAI.Chat.Completions.ChatCompletionCreateParams,
+  ): Promise<OpenAI.Chat.Completions.ChatCompletion | Stream<OpenAI.Chat.Completions.ChatCompletionChunk>> {
     return this.client.chat.completions.create(body);
   }
 

+ 1 - 0
apps/app/src/features/openai/server/services/client-delegator/index.ts

@@ -1 +1,2 @@
 export * from './get-client';
+export * from './is-stream-response';

+ 4 - 1
apps/app/src/features/openai/server/services/client-delegator/interfaces.ts

@@ -1,4 +1,5 @@
 import type OpenAI from 'openai';
+import { type Stream } from 'openai/streaming';
 import type { Uploadable } from 'openai/uploads';
 
 import type { MessageListParams } from '../../../interfaces/message';
@@ -16,5 +17,7 @@ export interface IOpenaiClientDelegator {
   createVectorStoreFile(vectorStoreId: string, fileId: string): Promise<OpenAI.VectorStores.Files.VectorStoreFile>
   createVectorStoreFileBatch(vectorStoreId: string, fileIds: string[]): Promise<OpenAI.VectorStores.FileBatches.VectorStoreFileBatch>
   deleteFile(fileId: string): Promise<OpenAI.Files.FileDeleted>;
-  chatCompletion(body: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming): Promise<OpenAI.Chat.Completions.ChatCompletion>
+  chatCompletion(
+    body: OpenAI.Chat.Completions.ChatCompletionCreateParams
+  ): Promise<OpenAI.Chat.Completions.ChatCompletion | Stream<OpenAI.Chat.Completions.ChatCompletionChunk>>
 }

+ 12 - 0
apps/app/src/features/openai/server/services/client-delegator/is-stream-response.ts

@@ -0,0 +1,12 @@
+import type OpenAI from 'openai';
+import { type Stream } from 'openai/streaming';
+
+type ChatCompletionResponse = OpenAI.Chat.Completions.ChatCompletion;
+type ChatCompletionStreamResponse = Stream<OpenAI.Chat.Completions.ChatCompletionChunk>
+
+// Type guard function
+export const isStreamResponse = (result: ChatCompletionResponse | ChatCompletionStreamResponse): result is ChatCompletionStreamResponse => {
+  // Type assertion is safe due to the constrained input types
+  const assertedResult = result as any;
+  return assertedResult.tee != null && assertedResult.toReadableStream != null;
+};

+ 4 - 1
apps/app/src/features/openai/server/services/client-delegator/openai-client-delegator.ts

@@ -1,4 +1,5 @@
 import OpenAI from 'openai';
+import { type Stream } from 'openai/streaming';
 import { type Uploadable } from 'openai/uploads';
 
 import { configManager } from '~/server/service/config-manager';
@@ -95,7 +96,9 @@ export class OpenaiClientDelegator implements IOpenaiClientDelegator {
     return this.client.vectorStores.fileBatches.uploadAndPoll(vectorStoreId, { files });
   }
 
-  async chatCompletion(body: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming): Promise<OpenAI.Chat.Completions.ChatCompletion> {
+  async chatCompletion(
+      body: OpenAI.Chat.Completions.ChatCompletionCreateParams,
+  ): Promise<OpenAI.Chat.Completions.ChatCompletion | Stream<OpenAI.Chat.Completions.ChatCompletionChunk>> {
     return this.client.chat.completions.create(body);
   }
 

+ 39 - 5
apps/app/src/features/openai/server/services/openai.ts

@@ -4,7 +4,7 @@ import { Readable, Transform, Writable } from 'stream';
 import { pipeline } from 'stream/promises';
 
 import type {
-  IUser, Ref, Lang, IPage,
+  IUser, Ref, Lang, IPage, Nullable,
 } from '@growi/core';
 import {
   PageGrant, getIdForRef, getIdStringForRef, isPopulated, type IUserHasId,
@@ -15,6 +15,7 @@ import escapeStringRegexp from 'escape-string-regexp';
 import createError from 'http-errors';
 import mongoose, { type HydratedDocument, type Types } from 'mongoose';
 import { type OpenAI, toFile } from 'openai';
+import { type ChatCompletionChunk } from 'openai/resources/chat/completions';
 
 import ExternalUserGroupRelation from '~/features/external-user-group/server/models/external-user-group-relation';
 import ThreadRelationModel, { type ThreadRelationDocument } from '~/features/openai/server/models/thread-relation';
@@ -45,7 +46,7 @@ import { convertMarkdownToHtml } from '../utils/convert-markdown-to-html';
 import { generateGlobPatterns } from '../utils/generate-glob-patterns';
 import { isVectorStoreCompatible } from '../utils/is-vector-store-compatible';
 
-import { getClient } from './client-delegator';
+import { getClient, isStreamResponse } from './client-delegator';
 import { openaiApiErrorHandler } from './openai-api-error-handler';
 import { replaceAnnotationWithPageLink } from './replace-annotation-with-page-link';
 
@@ -72,6 +73,7 @@ const convertPathPatternsToRegExp = (pagePathPatterns: string[]): Array<string |
 };
 
 export interface IOpenaiService {
+  generateAndProcessPreMessage(message: string, chunkProcessor: (chunk: ChatCompletionChunk) => void): Promise<void>
   createThread(userId: string, type: ThreadType, aiAssistantId?: string, initialUserMessage?: string): Promise<ThreadRelationDocument>;
   getThreadsByAiAssistantId(aiAssistantId: string): Promise<ThreadRelationDocument[]>
   deleteThread(threadRelationId: string): Promise<ThreadRelationDocument>;
@@ -108,7 +110,37 @@ class OpenaiService implements IOpenaiService {
     return getClient({ openaiServiceType });
   }
 
-  private async generateThreadTitle(message: string): Promise<string | null> {
+  async generateAndProcessPreMessage(message: string, chunkProcessor: (delta: ChatCompletionChunk) => void): Promise<void> {
+    const systemMessage = [
+      "Generate a message briefly confirming the user's question.",
+      'Please generate up to 20 characters',
+    ].join('');
+
+    const preMessageCompletion = await this.client.chatCompletion({
+      stream: true,
+      model: 'gpt-4.1-nano',
+      messages: [
+        {
+          role: 'system',
+          content: systemMessage,
+        },
+        {
+          role: 'user',
+          content: message,
+        },
+      ],
+    });
+
+    if (!isStreamResponse(preMessageCompletion)) {
+      return;
+    }
+
+    for await (const chunk of preMessageCompletion) {
+      chunkProcessor(chunk);
+    }
+  }
+
+  private async generateThreadTitle(message: string): Promise<Nullable<string>> {
     const systemMessage = [
       'Create a brief title (max 5 words) from your message.',
       'Respond in the same language the user uses in their input.',
@@ -129,8 +161,10 @@ class OpenaiService implements IOpenaiService {
       ],
     });
 
-    const threadTitle = threadTitleCompletion.choices[0].message.content;
-    return threadTitle;
+    if (!isStreamResponse(threadTitleCompletion)) {
+      const threadTitle = threadTitleCompletion.choices[0].message.content;
+      return threadTitle;
+    }
   }
 
   async createThread(userId: string, type: ThreadType, aiAssistantId?: string, initialUserMessage?: string): Promise<ThreadRelationDocument> {