Просмотр исходного кода

feat: update pre-message handling to include finished state and refactor related methods

Shun Miyazawa 9 месяцев назад
Родитель
Commit
4a44965ba4

+ 2 - 1
apps/app/src/features/openai/interfaces/knowledge-assistant/sse-schemas.ts

@@ -12,7 +12,8 @@ export const SseMessageSchema = z.object({
 });
 
 export const SsePreMessageSchema = z.object({
-  text: z.string().describe('The pre-message that should be appended to the chat window'),
+  text: z.string().nullish().describe('The pre-message that should be appended to the chat window'),
+  finished: z.boolean().describe('Indicates if the pre-message generation is finished'),
 });
 
 

+ 7 - 11
apps/app/src/features/openai/server/routes/message/post-message.ts

@@ -124,22 +124,18 @@ export const postMessageHandlersFactory: PostMessageHandlersFactory = (crowi) =>
         'Cache-Control': 'no-cache, no-transform',
       });
 
-      let isMainMessageGenerating = false;
+      const preMessageChunkHandler = (chunk: ChatCompletionChunk) => {
+        const chunkChoice = chunk.choices[0];
 
-      const preMessageDeltaHandler = (delta: ChatCompletionChunk.Choice.Delta) => {
-        if (isMainMessageGenerating) {
-          return;
-        }
+        const content = {
+          text: chunkChoice.delta.content,
+          finished: chunkChoice.finish_reason != null,
+        };
 
-        const content = { text: delta.content };
         res.write(`data: ${JSON.stringify(content)}\n\n`);
       };
 
       const messageDeltaHandler = async(delta: MessageDelta) => {
-        if (!isMainMessageGenerating) {
-          isMainMessageGenerating = true;
-        }
-
         const content = delta.content?.[0];
 
         // If annotation is found
@@ -155,7 +151,7 @@ export const postMessageHandlersFactory: PostMessageHandlersFactory = (crowi) =>
       };
 
       // Don't add await since SSE is performed asynchronously with main message
-      openaiService.generateAndProcessPreMessage(req.body.userMessage, preMessageDeltaHandler);
+      openaiService.generateAndProcessPreMessage(req.body.userMessage, preMessageChunkHandler);
 
       stream.on('event', (delta) => {
         if (delta.event === 'thread.run.failed') {

+ 3 - 4
apps/app/src/features/openai/server/services/openai.ts

@@ -73,7 +73,7 @@ const convertPathPatternsToRegExp = (pagePathPatterns: string[]): Array<string |
 };
 
 export interface IOpenaiService {
-  generateAndProcessPreMessage(message: string, deltaProcessor: (delta: ChatCompletionChunk.Choice.Delta) => void): Promise<void>
+  generateAndProcessPreMessage(message: string, chunkProcessor: (chunk: ChatCompletionChunk) => void): Promise<void>
   createThread(userId: string, type: ThreadType, aiAssistantId?: string, initialUserMessage?: string): Promise<ThreadRelationDocument>;
   getThreadsByAiAssistantId(aiAssistantId: string): Promise<ThreadRelationDocument[]>
   deleteThread(threadRelationId: string): Promise<ThreadRelationDocument>;
@@ -110,7 +110,7 @@ class OpenaiService implements IOpenaiService {
     return getClient({ openaiServiceType });
   }
 
-  async generateAndProcessPreMessage(message: string, deltaProcessor: (delta: ChatCompletionChunk.Choice.Delta) => void): Promise<void> {
+  async generateAndProcessPreMessage(message: string, chunkProcessor: (chunk: ChatCompletionChunk) => void): Promise<void> {
     const systemMessage = [
       "Generate a message briefly confirming the user's question.",
       'Please generate up to 20 characters',
@@ -136,8 +136,7 @@ class OpenaiService implements IOpenaiService {
     }
 
     for await (const chunk of preMessageCompletion) {
-      const delta = chunk.choices[0].delta;
-      deltaProcessor(delta);
+      chunkProcessor(chunk);
     }
   }