Pārlūkot izejas kodu

feat: update generatePreMessage to generateAndProcessPreMessage for enhanced streaming support

Shun Miyazawa 10 mēneši atpakaļ
vecāks
revīzija
6541c0c751

+ 16 - 6
apps/app/src/features/openai/server/services/openai.ts

@@ -3,7 +3,6 @@ import assert from 'node:assert';
 import { Readable, Transform, Writable } from 'stream';
 import { Readable, Transform, Writable } from 'stream';
 import { pipeline } from 'stream/promises';
 import { pipeline } from 'stream/promises';
 
 
-
 import type {
 import type {
   IUser, Ref, Lang, IPage, Nullable,
   IUser, Ref, Lang, IPage, Nullable,
 } from '@growi/core';
 } from '@growi/core';
@@ -16,6 +15,7 @@ import escapeStringRegexp from 'escape-string-regexp';
 import createError from 'http-errors';
 import createError from 'http-errors';
 import mongoose, { type HydratedDocument, type Types } from 'mongoose';
 import mongoose, { type HydratedDocument, type Types } from 'mongoose';
 import { type OpenAI, toFile } from 'openai';
 import { type OpenAI, toFile } from 'openai';
+import { type ChatCompletionChunk } from 'openai/resources/chat/completions';
 import { type Stream } from 'openai/streaming';
 import { type Stream } from 'openai/streaming';
 
 
 import ExternalUserGroupRelation from '~/features/external-user-group/server/models/external-user-group-relation';
 import ExternalUserGroupRelation from '~/features/external-user-group/server/models/external-user-group-relation';
@@ -74,7 +74,10 @@ const convertPathPatternsToRegExp = (pagePathPatterns: string[]): Array<string |
 };
 };
 
 
 export interface IOpenaiService {
 export interface IOpenaiService {
-  generatePreMessage(message: string): Promise<Nullable<Stream<OpenAI.Chat.Completions.ChatCompletionChunk>>>;
+  generateAndProcessPreMessage(
+      message: string,
+      deltaProcessor: (delta: ChatCompletionChunk.Choice.Delta) => void,
+  ): Promise<Nullable<Stream<OpenAI.Chat.Completions.ChatCompletionChunk>>>
   createThread(userId: string, type: ThreadType, aiAssistantId?: string, initialUserMessage?: string): Promise<ThreadRelationDocument>;
   createThread(userId: string, type: ThreadType, aiAssistantId?: string, initialUserMessage?: string): Promise<ThreadRelationDocument>;
   getThreadsByAiAssistantId(aiAssistantId: string): Promise<ThreadRelationDocument[]>
   getThreadsByAiAssistantId(aiAssistantId: string): Promise<ThreadRelationDocument[]>
   deleteThread(threadRelationId: string): Promise<ThreadRelationDocument>;
   deleteThread(threadRelationId: string): Promise<ThreadRelationDocument>;
@@ -111,8 +114,10 @@ class OpenaiService implements IOpenaiService {
     return getClient({ openaiServiceType });
     return getClient({ openaiServiceType });
   }
   }
 
 
-
-  async generatePreMessage(message: string): Promise<Nullable<Stream<OpenAI.Chat.Completions.ChatCompletionChunk>>> {
+  async generateAndProcessPreMessage(
+      message: string,
+      deltaProcessor: (delta: ChatCompletionChunk.Choice.Delta) => void,
+  ): Promise<Nullable<Stream<OpenAI.Chat.Completions.ChatCompletionChunk>>> {
     const systemMessage = [
     const systemMessage = [
       "Generate a message briefly confirming the user's question.",
       "Generate a message briefly confirming the user's question.",
       'Please generate up to 20 characters',
       'Please generate up to 20 characters',
@@ -133,8 +138,13 @@ class OpenaiService implements IOpenaiService {
       ],
       ],
     });
     });
 
 
-    if (isStreamResponse(preMessageCompletion)) {
-      return preMessageCompletion;
+    if (!isStreamResponse(preMessageCompletion)) {
+      return;
+    }
+
+    for await (const chunk of preMessageCompletion) {
+      const delta = chunk.choices[0].delta;
+      deltaProcessor(delta);
     }
     }
   }
   }