Просмотр исходного кода

Merge pull request #10086 from weseek/feat/167463-implement-pre-message-creation-logic

feat(ai): Send pre-message to client via SSE
Yuki Takei 9 месяцев назад
Родитель
Commit
0a4dde9a30

+ 3 - 0
apps/app/src/features/openai/client/components/AiAssistant/AiAssistantSidebar/AiAssistantSidebar.tsx

@@ -249,6 +249,9 @@ const AiAssistantSidebarSubstance: React.FC<AiAssistantSidebarSubstanceProps> =
             const data = JSON.parse(line.replace('data: ', ''));
             const data = JSON.parse(line.replace('data: ', ''));
 
 
             processMessageForKnowledgeAssistant(data, {
             processMessageForKnowledgeAssistant(data, {
+              onPreMessage: (data) => {
+                textValues.push(data.text);
+              },
               onMessage: (data) => {
               onMessage: (data) => {
                 textValues.push(data.content[0].text.value);
                 textValues.push(data.content[0].text.value);
               },
               },

+ 10 - 2
apps/app/src/features/openai/client/services/knowledge-assistant.tsx

@@ -10,7 +10,9 @@ import {
 } from 'reactstrap';
 } from 'reactstrap';
 
 
 import { apiv3Post } from '~/client/util/apiv3-client';
 import { apiv3Post } from '~/client/util/apiv3-client';
-import { SseMessageSchema, type SseMessage } from '~/features/openai/interfaces/knowledge-assistant/sse-schemas';
+import {
+  SseMessageSchema, type SseMessage, SsePreMessageSchema, type SsePreMessage,
+} from '~/features/openai/interfaces/knowledge-assistant/sse-schemas';
 import { handleIfSuccessfullyParsed } from '~/features/openai/utils/handle-if-successfully-parsed';
 import { handleIfSuccessfullyParsed } from '~/features/openai/utils/handle-if-successfully-parsed';
 
 
 import type { MessageLog, MessageWithCustomMetaData } from '../../interfaces/message';
 import type { MessageLog, MessageWithCustomMetaData } from '../../interfaces/message';
@@ -31,7 +33,9 @@ interface PostMessage {
 
 
 interface ProcessMessage {
 interface ProcessMessage {
   (data: unknown, handler: {
   (data: unknown, handler: {
-    onMessage: (data: SseMessage) => void}
+    onMessage: (data: SseMessage) => void
+    onPreMessage: (data: SsePreMessage) => void
+  }
   ): void;
   ): void;
 }
 }
 
 
@@ -121,6 +125,10 @@ export const useKnowledgeAssistant: UseKnowledgeAssistant = () => {
     handleIfSuccessfullyParsed(data, SseMessageSchema, (data: SseMessage) => {
     handleIfSuccessfullyParsed(data, SseMessageSchema, (data: SseMessage) => {
       handler.onMessage(data);
       handler.onMessage(data);
     });
     });
+
+    handleIfSuccessfullyParsed(data, SsePreMessageSchema, (data: SsePreMessage) => {
+      handler.onPreMessage(data);
+    });
   }, []);
   }, []);
 
 
   // Views
   // Views

+ 5 - 0
apps/app/src/features/openai/interfaces/knowledge-assistant/sse-schemas.ts

@@ -11,6 +11,11 @@ export const SseMessageSchema = z.object({
   })),
   })),
 });
 });
 
 
+export const SsePreMessageSchema = z.object({
+  text: z.string().describe('The pre-message that should be appended to the chat window'),
+});
+
 
 
 // Type definitions
 // Type definitions
 export type SseMessage = z.infer<typeof SseMessageSchema>;
 export type SseMessage = z.infer<typeof SseMessageSchema>;
+export type SsePreMessage = z.infer<typeof SsePreMessageSchema>;

+ 9 - 0
apps/app/src/features/openai/server/routes/message/post-message.ts

@@ -5,6 +5,7 @@ import type { ValidationChain } from 'express-validator';
 import { body } from 'express-validator';
 import { body } from 'express-validator';
 import type { AssistantStream } from 'openai/lib/AssistantStream';
 import type { AssistantStream } from 'openai/lib/AssistantStream';
 import type { MessageDelta } from 'openai/resources/beta/threads/messages.mjs';
 import type { MessageDelta } from 'openai/resources/beta/threads/messages.mjs';
+import { type ChatCompletionChunk } from 'openai/resources/chat/completions';
 
 
 import { getOrCreateChatAssistant } from '~/features/openai/server/services/assistant';
 import { getOrCreateChatAssistant } from '~/features/openai/server/services/assistant';
 import type Crowi from '~/server/crowi';
 import type Crowi from '~/server/crowi';
@@ -120,6 +121,14 @@ export const postMessageHandlersFactory: PostMessageHandlersFactory = (crowi) =>
         'Cache-Control': 'no-cache, no-transform',
         'Cache-Control': 'no-cache, no-transform',
       });
       });
 
 
+      const preMessageDeltaHandler = (delta: ChatCompletionChunk.Choice.Delta) => {
+        const content = { text: delta.content };
+        res.write(`data: ${JSON.stringify(content)}\n\n`);
+      };
+
+      // Don't add await since SSE is performed asynchronously with main message
+      openaiService.generateAndProcessPreMessage(req.body.userMessage, preMessageDeltaHandler);
+
       const messageDeltaHandler = async(delta: MessageDelta) => {
       const messageDeltaHandler = async(delta: MessageDelta) => {
         const content = delta.content?.[0];
         const content = delta.content?.[0];
 
 

+ 4 - 1
apps/app/src/features/openai/server/services/client-delegator/azure-openai-client-delegator.ts

@@ -1,6 +1,7 @@
 import { DefaultAzureCredential, getBearerTokenProvider } from '@azure/identity';
 import { DefaultAzureCredential, getBearerTokenProvider } from '@azure/identity';
 import type OpenAI from 'openai';
 import type OpenAI from 'openai';
 import { AzureOpenAI } from 'openai';
 import { AzureOpenAI } from 'openai';
+import { type Stream } from 'openai/streaming';
 import { type Uploadable } from 'openai/uploads';
 import { type Uploadable } from 'openai/uploads';
 
 
 import type { MessageListParams } from '../../../interfaces/message';
 import type { MessageListParams } from '../../../interfaces/message';
@@ -94,7 +95,9 @@ export class AzureOpenaiClientDelegator implements IOpenaiClientDelegator {
     return this.client.vectorStores.fileBatches.uploadAndPoll(vectorStoreId, { files });
     return this.client.vectorStores.fileBatches.uploadAndPoll(vectorStoreId, { files });
   }
   }
 
 
-  async chatCompletion(body: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming): Promise<OpenAI.Chat.Completions.ChatCompletion> {
+  async chatCompletion(
+      body: OpenAI.Chat.Completions.ChatCompletionCreateParams,
+  ): Promise<OpenAI.Chat.Completions.ChatCompletion | Stream<OpenAI.Chat.Completions.ChatCompletionChunk>> {
     return this.client.chat.completions.create(body);
     return this.client.chat.completions.create(body);
   }
   }
 
 

+ 1 - 0
apps/app/src/features/openai/server/services/client-delegator/index.ts

@@ -1 +1,2 @@
 export * from './get-client';
 export * from './get-client';
+export * from './is-stream-response';

+ 4 - 1
apps/app/src/features/openai/server/services/client-delegator/interfaces.ts

@@ -1,4 +1,5 @@
 import type OpenAI from 'openai';
 import type OpenAI from 'openai';
+import { type Stream } from 'openai/streaming';
 import type { Uploadable } from 'openai/uploads';
 import type { Uploadable } from 'openai/uploads';
 
 
 import type { MessageListParams } from '../../../interfaces/message';
 import type { MessageListParams } from '../../../interfaces/message';
@@ -16,5 +17,7 @@ export interface IOpenaiClientDelegator {
   createVectorStoreFile(vectorStoreId: string, fileId: string): Promise<OpenAI.VectorStores.Files.VectorStoreFile>
   createVectorStoreFile(vectorStoreId: string, fileId: string): Promise<OpenAI.VectorStores.Files.VectorStoreFile>
   createVectorStoreFileBatch(vectorStoreId: string, fileIds: string[]): Promise<OpenAI.VectorStores.FileBatches.VectorStoreFileBatch>
   createVectorStoreFileBatch(vectorStoreId: string, fileIds: string[]): Promise<OpenAI.VectorStores.FileBatches.VectorStoreFileBatch>
   deleteFile(fileId: string): Promise<OpenAI.Files.FileDeleted>;
   deleteFile(fileId: string): Promise<OpenAI.Files.FileDeleted>;
-  chatCompletion(body: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming): Promise<OpenAI.Chat.Completions.ChatCompletion>
+  chatCompletion(
+    body: OpenAI.Chat.Completions.ChatCompletionCreateParams
+  ): Promise<OpenAI.Chat.Completions.ChatCompletion | Stream<OpenAI.Chat.Completions.ChatCompletionChunk>>
 }
 }

+ 12 - 0
apps/app/src/features/openai/server/services/client-delegator/is-stream-response.ts

@@ -0,0 +1,12 @@
+import type OpenAI from 'openai';
+import { type Stream } from 'openai/streaming';
+
+type ChatCompletionResponse = OpenAI.Chat.Completions.ChatCompletion;
+type ChatCompletionStreamResponse = Stream<OpenAI.Chat.Completions.ChatCompletionChunk>
+
+// Type guard function
+export const isStreamResponse = (result: ChatCompletionResponse | ChatCompletionStreamResponse): result is ChatCompletionStreamResponse => {
+  // Type assertion is safe due to the constrained input types
+  const assertedResult = result as any;
+  return assertedResult.tee != null && assertedResult.toReadableStream != null;
+};

+ 4 - 1
apps/app/src/features/openai/server/services/client-delegator/openai-client-delegator.ts

@@ -1,4 +1,5 @@
 import OpenAI from 'openai';
 import OpenAI from 'openai';
+import { type Stream } from 'openai/streaming';
 import { type Uploadable } from 'openai/uploads';
 import { type Uploadable } from 'openai/uploads';
 
 
 import { configManager } from '~/server/service/config-manager';
 import { configManager } from '~/server/service/config-manager';
@@ -95,7 +96,9 @@ export class OpenaiClientDelegator implements IOpenaiClientDelegator {
     return this.client.vectorStores.fileBatches.uploadAndPoll(vectorStoreId, { files });
     return this.client.vectorStores.fileBatches.uploadAndPoll(vectorStoreId, { files });
   }
   }
 
 
-  async chatCompletion(body: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming): Promise<OpenAI.Chat.Completions.ChatCompletion> {
+  async chatCompletion(
+      body: OpenAI.Chat.Completions.ChatCompletionCreateParams,
+  ): Promise<OpenAI.Chat.Completions.ChatCompletion | Stream<OpenAI.Chat.Completions.ChatCompletionChunk>> {
     return this.client.chat.completions.create(body);
     return this.client.chat.completions.create(body);
   }
   }
 
 

+ 47 - 5
apps/app/src/features/openai/server/services/openai.ts

@@ -4,7 +4,7 @@ import { Readable, Transform, Writable } from 'stream';
 import { pipeline } from 'stream/promises';
 import { pipeline } from 'stream/promises';
 
 
 import type {
 import type {
-  IUser, Ref, Lang, IPage,
+  IUser, Ref, Lang, IPage, Nullable,
 } from '@growi/core';
 } from '@growi/core';
 import {
 import {
   PageGrant, getIdForRef, getIdStringForRef, isPopulated, type IUserHasId,
   PageGrant, getIdForRef, getIdStringForRef, isPopulated, type IUserHasId,
@@ -15,6 +15,8 @@ import escapeStringRegexp from 'escape-string-regexp';
 import createError from 'http-errors';
 import createError from 'http-errors';
 import mongoose, { type HydratedDocument, type Types } from 'mongoose';
 import mongoose, { type HydratedDocument, type Types } from 'mongoose';
 import { type OpenAI, toFile } from 'openai';
 import { type OpenAI, toFile } from 'openai';
+import { type ChatCompletionChunk } from 'openai/resources/chat/completions';
+import { type Stream } from 'openai/streaming';
 
 
 import ExternalUserGroupRelation from '~/features/external-user-group/server/models/external-user-group-relation';
 import ExternalUserGroupRelation from '~/features/external-user-group/server/models/external-user-group-relation';
 import ThreadRelationModel, { type ThreadRelationDocument } from '~/features/openai/server/models/thread-relation';
 import ThreadRelationModel, { type ThreadRelationDocument } from '~/features/openai/server/models/thread-relation';
@@ -45,7 +47,7 @@ import { convertMarkdownToHtml } from '../utils/convert-markdown-to-html';
 import { generateGlobPatterns } from '../utils/generate-glob-patterns';
 import { generateGlobPatterns } from '../utils/generate-glob-patterns';
 import { isVectorStoreCompatible } from '../utils/is-vector-store-compatible';
 import { isVectorStoreCompatible } from '../utils/is-vector-store-compatible';
 
 
-import { getClient } from './client-delegator';
+import { getClient, isStreamResponse } from './client-delegator';
 import { openaiApiErrorHandler } from './openai-api-error-handler';
 import { openaiApiErrorHandler } from './openai-api-error-handler';
 import { replaceAnnotationWithPageLink } from './replace-annotation-with-page-link';
 import { replaceAnnotationWithPageLink } from './replace-annotation-with-page-link';
 
 
@@ -72,6 +74,10 @@ const convertPathPatternsToRegExp = (pagePathPatterns: string[]): Array<string |
 };
 };
 
 
 export interface IOpenaiService {
 export interface IOpenaiService {
+  generateAndProcessPreMessage(
+      message: string,
+      deltaProcessor: (delta: ChatCompletionChunk.Choice.Delta) => void,
+  ): Promise<Nullable<Stream<OpenAI.Chat.Completions.ChatCompletionChunk>>>
   createThread(userId: string, type: ThreadType, aiAssistantId?: string, initialUserMessage?: string): Promise<ThreadRelationDocument>;
   createThread(userId: string, type: ThreadType, aiAssistantId?: string, initialUserMessage?: string): Promise<ThreadRelationDocument>;
   getThreadsByAiAssistantId(aiAssistantId: string): Promise<ThreadRelationDocument[]>
   getThreadsByAiAssistantId(aiAssistantId: string): Promise<ThreadRelationDocument[]>
   deleteThread(threadRelationId: string): Promise<ThreadRelationDocument>;
   deleteThread(threadRelationId: string): Promise<ThreadRelationDocument>;
@@ -108,7 +114,41 @@ class OpenaiService implements IOpenaiService {
     return getClient({ openaiServiceType });
     return getClient({ openaiServiceType });
   }
   }
 
 
-  private async generateThreadTitle(message: string): Promise<string | null> {
+  async generateAndProcessPreMessage(
+      message: string,
+      deltaProcessor: (delta: ChatCompletionChunk.Choice.Delta) => void,
+  ): Promise<Nullable<Stream<OpenAI.Chat.Completions.ChatCompletionChunk>>> {
+    const systemMessage = [
+      "Generate a message briefly confirming the user's question.",
+      'Please generate up to 20 characters',
+    ].join('');
+
+    const preMessageCompletion = await this.client.chatCompletion({
+      stream: true,
+      model: 'gpt-4.1-nano',
+      messages: [
+        {
+          role: 'system',
+          content: systemMessage,
+        },
+        {
+          role: 'user',
+          content: message,
+        },
+      ],
+    });
+
+    if (!isStreamResponse(preMessageCompletion)) {
+      return;
+    }
+
+    for await (const chunk of preMessageCompletion) {
+      const delta = chunk.choices[0].delta;
+      deltaProcessor(delta);
+    }
+  }
+
+  private async generateThreadTitle(message: string): Promise<Nullable<string>> {
     const systemMessage = [
     const systemMessage = [
       'Create a brief title (max 5 words) from your message.',
       'Create a brief title (max 5 words) from your message.',
       'Respond in the same language the user uses in their input.',
       'Respond in the same language the user uses in their input.',
@@ -129,8 +169,10 @@ class OpenaiService implements IOpenaiService {
       ],
       ],
     });
     });
 
 
-    const threadTitle = threadTitleCompletion.choices[0].message.content;
-    return threadTitle;
+    if (!isStreamResponse(threadTitleCompletion)) {
+      const threadTitle = threadTitleCompletion.choices[0].message.content;
+      return threadTitle;
+    }
   }
   }
 
 
   async createThread(userId: string, type: ThreadType, aiAssistantId?: string, initialUserMessage?: string): Promise<ThreadRelationDocument> {
   async createThread(userId: string, type: ThreadType, aiAssistantId?: string, initialUserMessage?: string): Promise<ThreadRelationDocument> {