Shun Miyazawa 1 год назад
Родитель
Сommit
d482677a15

+ 9 - 1
apps/app/src/features/openai/server/routes/thread.ts

@@ -11,6 +11,7 @@ import { apiV3FormValidator } from '~/server/middlewares/apiv3-form-validator';
 import type { ApiV3Response } from '~/server/routes/apiv3/interfaces/apiv3-response';
 import loggerFactory from '~/utils/logger';
 
+import AiAssistantModel from '../models/ai-assistant';
 import { getOpenaiService } from '../services/openai';
 
 import { certifyAiService } from './middlewares/certify-ai-service';
@@ -53,10 +54,17 @@ export const createThreadHandlersFactory: CreateThreadFactory = (crowi) => {
           return res.apiv3Err(new ErrorV3('The specified AI assistant is not usable'), 400);
         }
 
+        const aiAssistant = await AiAssistantModel.findById(aiAssistantId);
+        if (aiAssistant == null) {
+          return res.apiv3Err(new ErrorV3('AI assistant not found'), 404);
+        }
+
+        const additionalInstruction = aiAssistant.additionalInstruction;
+
         const filteredThreadId = threadId != null ? filterXSS(threadId) : undefined;
         const vectorStoreRelation = await openaiService.getVectorStoreRelation(aiAssistantId);
 
-        const thread = await openaiService.getOrCreateThread(req.user._id, vectorStoreRelation, filteredThreadId, initialUserMessage);
+        const thread = await openaiService.getOrCreateThread(req.user._id, vectorStoreRelation, filteredThreadId, initialUserMessage, additionalInstruction);
         return res.apiv3(thread);
       }
       catch (err) {

+ 4 - 1
apps/app/src/features/openai/server/services/client-delegator/azure-openai-client-delegator.ts

@@ -20,8 +20,11 @@ export class AzureOpenaiClientDelegator implements IOpenaiClientDelegator {
     // TODO: initialize openaiVectorStoreId property
   }
 
-  async createThread(vectorStoreId: string): Promise<OpenAI.Beta.Threads.Thread> {
+  async createThread(vectorStoreId: string, additionalInstruction?: string): Promise<OpenAI.Beta.Threads.Thread> {
     return this.client.beta.threads.create({
+      messages: additionalInstruction != null
+        ? [{ role: 'assistant', content: additionalInstruction }]
+        : [],
       tool_resources: {
         file_search: {
           vector_store_ids: [vectorStoreId],

+ 1 - 1
apps/app/src/features/openai/server/services/client-delegator/interfaces.ts

@@ -2,7 +2,7 @@ import type OpenAI from 'openai';
 import type { Uploadable } from 'openai/uploads';
 
 export interface IOpenaiClientDelegator {
-  createThread(vectorStoreId: string): Promise<OpenAI.Beta.Threads.Thread>
+  createThread(vectorStoreId: string, additionalInstruction?: string): Promise<OpenAI.Beta.Threads.Thread>
   retrieveThread(threadId: string): Promise<OpenAI.Beta.Threads.Thread>
   deleteThread(threadId: string): Promise<OpenAI.Beta.Threads.ThreadDeleted>
   getMessages(threadId: string, options?: { limit: number, before: string, after: string }): Promise<OpenAI.Beta.Threads.Messages.MessagesPage>

+ 4 - 1
apps/app/src/features/openai/server/services/client-delegator/openai-client-delegator.ts

@@ -23,8 +23,11 @@ export class OpenaiClientDelegator implements IOpenaiClientDelegator {
     this.client = new OpenAI({ apiKey });
   }
 
-  async createThread(vectorStoreId: string): Promise<OpenAI.Beta.Threads.Thread> {
+  async createThread(vectorStoreId: string, additionalInstruction?: string): Promise<OpenAI.Beta.Threads.Thread> {
     return this.client.beta.threads.create({
+      messages: additionalInstruction != null
+        ? [{ role: 'assistant', content: additionalInstruction }]
+        : [],
       tool_resources: {
         file_search: {
           vector_store_ids: [vectorStoreId],

+ 7 - 3
apps/app/src/features/openai/server/services/openai.ts

@@ -63,7 +63,11 @@ const convertPathPatternsToRegExp = (pagePathPatterns: string[]): Array<string |
 
 export interface IOpenaiService {
   getOrCreateThread(
-    userId: string, vectorStoreRelation: VectorStoreDocument, threadId?: string, initialUserMessage?: string
+    userId: string,
+    vectorStoreRelation: VectorStoreDocument,
+    threadId?: string,
+    initialUserMessage?: string,
+    additionalInstruction?: string,
   ): Promise<ThreadRelationDocument>;
   getThreads(vectorStoreRelationId: string): Promise<ThreadRelationDocument[]>
   // getOrCreateVectorStoreForPublicScope(): Promise<VectorStoreDocument>;
@@ -122,7 +126,7 @@ class OpenaiService implements IOpenaiService {
   }
 
   async getOrCreateThread(
-      userId: string, vectorStoreRelation: VectorStoreDocument, threadId?: string, initialUserMessage?: string,
+      userId: string, vectorStoreRelation: VectorStoreDocument, threadId?: string, initialUserMessage?: string, additionalInstruction?: string,
   ): Promise<ThreadRelationDocument> {
     if (threadId == null) {
       let threadTitle: string | null = null;
@@ -136,7 +140,7 @@ class OpenaiService implements IOpenaiService {
       }
 
       try {
-        const thread = await this.client.createThread(vectorStoreRelation.vectorStoreId);
+        const thread = await this.client.createThread(vectorStoreRelation.vectorStoreId, additionalInstruction);
         const threadRelation = await ThreadRelationModel.create({
           userId,
           threadId: thread.id,