|
|
@@ -1,3 +1,4 @@
|
|
|
+import { getIdStringForRef } from '@growi/core';
|
|
|
import type { IUserHasId } from '@growi/core/dist/interfaces';
|
|
|
import { ErrorV3 } from '@growi/core/dist/models';
|
|
|
import type { Request, RequestHandler, Response } from 'express';
|
|
|
@@ -17,6 +18,7 @@ import loggerFactory from '~/utils/logger';
|
|
|
import { LlmEditorAssistantDiffSchema, LlmEditorAssistantMessageSchema } from '../../../interfaces/editor-assistant/llm-response-schemas';
|
|
|
import type { SseDetectedDiff, SseFinalized, SseMessage } from '../../../interfaces/editor-assistant/sse-schemas';
|
|
|
import { MessageErrorCode } from '../../../interfaces/message-error';
|
|
|
+import ThreadRelationModel from '../../models/thread-relation';
|
|
|
import { getOrCreateEditorAssistant } from '../../services/assistant';
|
|
|
import { openaiClient } from '../../services/client';
|
|
|
import { LlmResponseStreamProcessor } from '../../services/editor-assistant';
|
|
|
@@ -41,7 +43,6 @@ const LlmEditorAssistantResponseSchema = z.object({
|
|
|
type ReqBody = {
|
|
|
userMessage: string,
|
|
|
markdown: string,
|
|
|
- aiAssistantId?: string,
|
|
|
threadId?: string,
|
|
|
}
|
|
|
|
|
|
@@ -74,14 +75,15 @@ export const postMessageToEditHandlersFactory: PostMessageHandlersFactory = (cro
|
|
|
.withMessage('markdown must be string')
|
|
|
.notEmpty()
|
|
|
.withMessage('markdown must be set'),
|
|
|
- body('aiAssistantId').optional().isMongoId().withMessage('aiAssistantId must be string'),
|
|
|
body('threadId').optional().isString().withMessage('threadId must be string'),
|
|
|
];
|
|
|
|
|
|
return [
|
|
|
accessTokenParser, loginRequiredStrictly, certifyAiService, validator, apiV3FormValidator,
|
|
|
async(req: Req, res: ApiV3Response) => {
|
|
|
- const { userMessage, markdown, threadId } = req.body;
|
|
|
+ const {
|
|
|
+ userMessage, markdown, threadId,
|
|
|
+ } = req.body;
|
|
|
|
|
|
// Parameter check
|
|
|
if (threadId == null) {
|
|
|
@@ -94,6 +96,20 @@ export const postMessageToEditHandlersFactory: PostMessageHandlersFactory = (cro
|
|
|
return res.apiv3Err(new ErrorV3('GROWI AI is not enabled'), 501);
|
|
|
}
|
|
|
|
|
|
+ const threadRelation = await ThreadRelationModel.findOne({ threadId });
|
|
|
+ if (threadRelation == null) {
|
|
|
+ return res.apiv3Err(new ErrorV3('ThreadRelation not found'), 404);
|
|
|
+ }
|
|
|
+
|
|
|
+ // Check if usable
|
|
|
+ if (threadRelation.aiAssistant != null) {
|
|
|
+ const aiAssistantId = getIdStringForRef(threadRelation.aiAssistant);
|
|
|
+ const isAiAssistantUsable = await openaiService.isAiAssistantUsable(aiAssistantId, req.user);
|
|
|
+ if (!isAiAssistantUsable) {
|
|
|
+ return res.apiv3Err(new ErrorV3('The specified AI assistant is not usable'), 400);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
// Initialize SSE helper and stream processor
|
|
|
const sseHelper = new SseHelper(res);
|
|
|
const streamProcessor = new LlmResponseStreamProcessor({
|