Răsfoiți Sursa

feat: enhance OpenAI client with streaming support and type guard

Shun Miyazawa 9 luni în urmă
părinte
comite
f7ec43205b

+ 4 - 1
apps/app/src/features/openai/server/services/client-delegator/azure-openai-client-delegator.ts

@@ -1,6 +1,7 @@
 import { DefaultAzureCredential, getBearerTokenProvider } from '@azure/identity';
 import { DefaultAzureCredential, getBearerTokenProvider } from '@azure/identity';
 import type OpenAI from 'openai';
 import type OpenAI from 'openai';
 import { AzureOpenAI } from 'openai';
 import { AzureOpenAI } from 'openai';
+import { type Stream } from 'openai/streaming';
 import { type Uploadable } from 'openai/uploads';
 import { type Uploadable } from 'openai/uploads';
 
 
 import type { MessageListParams } from '../../../interfaces/message';
 import type { MessageListParams } from '../../../interfaces/message';
@@ -94,7 +95,9 @@ export class AzureOpenaiClientDelegator implements IOpenaiClientDelegator {
     return this.client.vectorStores.fileBatches.uploadAndPoll(vectorStoreId, { files });
     return this.client.vectorStores.fileBatches.uploadAndPoll(vectorStoreId, { files });
   }
   }
 
 
-  async chatCompletion(body: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming): Promise<OpenAI.Chat.Completions.ChatCompletion> {
+  async chatCompletion(
+      body: OpenAI.Chat.Completions.ChatCompletionCreateParams,
+  ): Promise<OpenAI.Chat.Completions.ChatCompletion | Stream<OpenAI.Chat.Completions.ChatCompletionChunk>> {
     return this.client.chat.completions.create(body);
     return this.client.chat.completions.create(body);
   }
   }
 
 

+ 1 - 0
apps/app/src/features/openai/server/services/client-delegator/index.ts

@@ -1 +1,2 @@
 export * from './get-client';
 export * from './get-client';
+export * from './is-stream-response';

+ 4 - 1
apps/app/src/features/openai/server/services/client-delegator/interfaces.ts

@@ -1,4 +1,5 @@
 import type OpenAI from 'openai';
 import type OpenAI from 'openai';
+import { type Stream } from 'openai/streaming';
 import type { Uploadable } from 'openai/uploads';
 import type { Uploadable } from 'openai/uploads';
 
 
 import type { MessageListParams } from '../../../interfaces/message';
 import type { MessageListParams } from '../../../interfaces/message';
@@ -16,5 +17,7 @@ export interface IOpenaiClientDelegator {
   createVectorStoreFile(vectorStoreId: string, fileId: string): Promise<OpenAI.VectorStores.Files.VectorStoreFile>
   createVectorStoreFile(vectorStoreId: string, fileId: string): Promise<OpenAI.VectorStores.Files.VectorStoreFile>
   createVectorStoreFileBatch(vectorStoreId: string, fileIds: string[]): Promise<OpenAI.VectorStores.FileBatches.VectorStoreFileBatch>
   createVectorStoreFileBatch(vectorStoreId: string, fileIds: string[]): Promise<OpenAI.VectorStores.FileBatches.VectorStoreFileBatch>
   deleteFile(fileId: string): Promise<OpenAI.Files.FileDeleted>;
   deleteFile(fileId: string): Promise<OpenAI.Files.FileDeleted>;
-  chatCompletion(body: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming): Promise<OpenAI.Chat.Completions.ChatCompletion>
+  chatCompletion(
+    body: OpenAI.Chat.Completions.ChatCompletionCreateParams
+  ): Promise<OpenAI.Chat.Completions.ChatCompletion | Stream<OpenAI.Chat.Completions.ChatCompletionChunk>>
 }
 }

+ 12 - 0
apps/app/src/features/openai/server/services/client-delegator/is-stream-response.ts

@@ -0,0 +1,12 @@
+import type OpenAI from 'openai';
+import { type Stream } from 'openai/streaming';
+
+type ChatCompletionResponse = OpenAI.Chat.Completions.ChatCompletion;
+type ChatCompletionStreamResponse = Stream<OpenAI.Chat.Completions.ChatCompletionChunk>
+
+// Type guard function
+export const isStreamResponse = (result: ChatCompletionResponse | ChatCompletionStreamResponse): result is ChatCompletionStreamResponse => {
+  // Type assertion is safe due to the constrained input types
+  const assertedResult = result as any;
+  return assertedResult.tee != null && assertedResult.toReadableStream != null;
+};

+ 4 - 1
apps/app/src/features/openai/server/services/client-delegator/openai-client-delegator.ts

@@ -1,4 +1,5 @@
 import OpenAI from 'openai';
 import OpenAI from 'openai';
+import { type Stream } from 'openai/streaming';
 import { type Uploadable } from 'openai/uploads';
 import { type Uploadable } from 'openai/uploads';
 
 
 import { configManager } from '~/server/service/config-manager';
 import { configManager } from '~/server/service/config-manager';
@@ -95,7 +96,9 @@ export class OpenaiClientDelegator implements IOpenaiClientDelegator {
     return this.client.vectorStores.fileBatches.uploadAndPoll(vectorStoreId, { files });
     return this.client.vectorStores.fileBatches.uploadAndPoll(vectorStoreId, { files });
   }
   }
 
 
-  async chatCompletion(body: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming): Promise<OpenAI.Chat.Completions.ChatCompletion> {
+  async chatCompletion(
+      body: OpenAI.Chat.Completions.ChatCompletionCreateParams,
+  ): Promise<OpenAI.Chat.Completions.ChatCompletion | Stream<OpenAI.Chat.Completions.ChatCompletionChunk>> {
     return this.client.chat.completions.create(body);
     return this.client.chat.completions.create(body);
   }
   }
 
 

+ 7 - 5
apps/app/src/features/openai/server/services/openai.ts

@@ -4,7 +4,7 @@ import { Readable, Transform, Writable } from 'stream';
 import { pipeline } from 'stream/promises';
 import { pipeline } from 'stream/promises';
 
 
 import type {
 import type {
-  IUser, Ref, Lang, IPage,
+  IUser, Ref, Lang, IPage, Nullable,
 } from '@growi/core';
 } from '@growi/core';
 import {
 import {
   PageGrant, getIdForRef, getIdStringForRef, isPopulated, type IUserHasId,
   PageGrant, getIdForRef, getIdStringForRef, isPopulated, type IUserHasId,
@@ -45,7 +45,7 @@ import { convertMarkdownToHtml } from '../utils/convert-markdown-to-html';
 import { generateGlobPatterns } from '../utils/generate-glob-patterns';
 import { generateGlobPatterns } from '../utils/generate-glob-patterns';
 import { isVectorStoreCompatible } from '../utils/is-vector-store-compatible';
 import { isVectorStoreCompatible } from '../utils/is-vector-store-compatible';
 
 
-import { getClient } from './client-delegator';
+import { getClient, isStreamResponse } from './client-delegator';
 import { openaiApiErrorHandler } from './openai-api-error-handler';
 import { openaiApiErrorHandler } from './openai-api-error-handler';
 import { replaceAnnotationWithPageLink } from './replace-annotation-with-page-link';
 import { replaceAnnotationWithPageLink } from './replace-annotation-with-page-link';
 
 
@@ -108,7 +108,7 @@ class OpenaiService implements IOpenaiService {
     return getClient({ openaiServiceType });
     return getClient({ openaiServiceType });
   }
   }
 
 
-  private async generateThreadTitle(message: string): Promise<string | null> {
+  private async generateThreadTitle(message: string): Promise<Nullable<string>> {
     const systemMessage = [
     const systemMessage = [
       'Create a brief title (max 5 words) from your message.',
       'Create a brief title (max 5 words) from your message.',
       'Respond in the same language the user uses in their input.',
       'Respond in the same language the user uses in their input.',
@@ -129,8 +129,10 @@ class OpenaiService implements IOpenaiService {
       ],
       ],
     });
     });
 
 
-    const threadTitle = threadTitleCompletion.choices[0].message.content;
-    return threadTitle;
+    if (!isStreamResponse(threadTitleCompletion)) {
+      const threadTitle = threadTitleCompletion.choices[0].message.content;
+      return threadTitle;
+    }
   }
   }
 
 
   async createThread(userId: string, type: ThreadType, aiAssistantId?: string, initialUserMessage?: string): Promise<ThreadRelationDocument> {
   async createThread(userId: string, type: ThreadType, aiAssistantId?: string, initialUserMessage?: string): Promise<ThreadRelationDocument> {