|
|
@@ -4,6 +4,7 @@ import { body } from 'express-validator';
|
|
|
|
|
|
import type Crowi from '~/server/crowi';
|
|
|
import { openaiClient } from '~/server/service/openai';
|
|
|
+import { getOrCreateSearchAssistant } from '~/server/service/openai/assistant';
|
|
|
import loggerFactory from '~/utils/logger';
|
|
|
|
|
|
import { apiV3FormValidator } from '../../../middlewares/apiv3-form-validator';
|
|
|
@@ -13,6 +14,7 @@ const logger = loggerFactory('growi:routes:apiv3:openai:chat');
|
|
|
|
|
|
type ReqBody = {
|
|
|
userMessage: string,
|
|
|
+ threadId?: string,
|
|
|
}
|
|
|
|
|
|
type Req = Request<undefined, ApiV3Response, ReqBody>
|
|
|
@@ -25,18 +27,41 @@ export const chatHandlersFactory: ChatHandlersFactory = (crowi) => {
|
|
|
|
|
|
const validator: ValidationChain[] = [
|
|
|
body('userMessage').isString().withMessage('userMessage must be string'),
|
|
|
+ body('threadId').optional().isString().withMessage('threadId must be string'),
|
|
|
];
|
|
|
|
|
|
return [
|
|
|
accessTokenParser, loginRequiredStrictly, validator, apiV3FormValidator,
|
|
|
async(req: Req, res: ApiV3Response) => {
|
|
|
+ const assistantId = process.env.OPENAI_ASSISTANT_ID;
|
|
|
+ const vectorStoreId = process.env.OPENAI_VECTOR_STORE_ID;
|
|
|
+ if (assistantId == null || vectorStoreId == null) {
|
|
|
+ return res.apiv3Err('OPENAI_ASSISTANT_ID or OPENAI_VECTOR_STORE_ID is not setup', 503);
|
|
|
+ }
|
|
|
+
|
|
|
try {
|
|
|
- const chatCompletion = await openaiClient.chat.completions.create({
|
|
|
- messages: [{ role: 'assistant', content: req.body.userMessage }],
|
|
|
- model: 'gpt-4o',
|
|
|
- });
|
|
|
+ await getOrCreateSearchAssistant();
|
|
|
+
|
|
|
+ const threadId = req.body.threadId;
|
|
|
+ const thread = threadId == null
|
|
|
+ ? await openaiClient.beta.threads.create({
|
|
|
+ messages: [{ role: 'assistant', content: req.body.userMessage }],
|
|
|
+ tool_resources: {
|
|
|
+ file_search: {
|
|
|
+ vector_store_ids: [vectorStoreId],
|
|
|
+ },
|
|
|
+ },
|
|
|
+ })
|
|
|
+ : await openaiClient.beta.threads.retrieve(threadId);
|
|
|
+
|
|
|
+ const run = await openaiClient.beta.threads.runs.createAndPoll(thread.id, { assistant_id: assistantId });
|
|
|
+
|
|
|
+ if (run.status === 'completed') {
|
|
|
+ const messages = await openaiClient.beta.threads.messages.list(run.thread_id);
|
|
|
+ return res.apiv3({ messages });
|
|
|
+ }
|
|
|
|
|
|
- return res.apiv3({ assistantMessage: chatCompletion.choices[0].message.content });
|
|
|
+ return res.apiv3({});
|
|
|
}
|
|
|
catch (err) {
|
|
|
logger.error(err);
|