thread.ts 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. import type { IUserHasId } from '@growi/core/dist/interfaces';
  2. import { ErrorV3 } from '@growi/core/dist/models';
  3. import type { Request, RequestHandler } from 'express';
  4. import type { ValidationChain } from 'express-validator';
  5. import { body } from 'express-validator';
  6. import { filterXSS } from 'xss';
  7. import type Crowi from '~/server/crowi';
  8. import { accessTokenParser } from '~/server/middlewares/access-token-parser';
  9. import { apiV3FormValidator } from '~/server/middlewares/apiv3-form-validator';
  10. import type { ApiV3Response } from '~/server/routes/apiv3/interfaces/apiv3-response';
  11. import loggerFactory from '~/utils/logger';
  12. import { getOpenaiService } from '../services/openai';
  13. import { certifyAiService } from './middlewares/certify-ai-service';
  14. const logger = loggerFactory('growi:routes:apiv3:openai:thread');
  15. type ReqBody = {
  16. aiAssistantId: string,
  17. threadId?: string,
  18. }
  19. type CreateThreadReq = Request<undefined, ApiV3Response, ReqBody> & { user: IUserHasId };
  20. type CreateThreadFactory = (crowi: Crowi) => RequestHandler[];
  21. export const createThreadHandlersFactory: CreateThreadFactory = (crowi) => {
  22. const loginRequiredStrictly = require('~/server/middlewares/login-required')(crowi);
  23. const validator: ValidationChain[] = [
  24. body('aiAssistantId').isMongoId().withMessage('aiAssistantId must be string'),
  25. body('threadId').optional().isString().withMessage('threadId must be string'),
  26. ];
  27. return [
  28. accessTokenParser, loginRequiredStrictly, certifyAiService, validator, apiV3FormValidator,
  29. async(req: CreateThreadReq, res: ApiV3Response) => {
  30. const openaiService = getOpenaiService();
  31. if (openaiService == null) {
  32. return res.apiv3Err(new ErrorV3('GROWI AI is not enabled'), 501);
  33. }
  34. try {
  35. const { aiAssistantId, threadId } = req.body;
  36. // リクエストした user が AiAssistant の owner or shareScope に含まれているかチェックする
  37. const vectorStoreRelation = await openaiService.getVectorStoreRelation(aiAssistantId);
  38. const filteredThreadId = threadId != null ? filterXSS(threadId) : undefined;
  39. const thread = await openaiService.getOrCreateThread(req.user._id, vectorStoreRelation, filteredThreadId);
  40. return res.apiv3({ thread });
  41. }
  42. catch (err) {
  43. logger.error(err);
  44. return res.apiv3Err(err);
  45. }
  46. },
  47. ];
  48. };