thread.ts 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. import type { IUserHasId } from '@growi/core/dist/interfaces';
  2. import type { Request, RequestHandler } from 'express';
  3. import type { ValidationChain } from 'express-validator';
  4. import { body } from 'express-validator';
  5. import { filterXSS } from 'xss';
  6. import type Crowi from '~/server/crowi';
  7. import { accessTokenParser } from '~/server/middlewares/access-token-parser';
  8. import { apiV3FormValidator } from '~/server/middlewares/apiv3-form-validator';
  9. import type { ApiV3Response } from '~/server/routes/apiv3/interfaces/apiv3-response';
  10. import loggerFactory from '~/utils/logger';
  11. import { getOpenaiService } from '../services/openai';
  12. import { certifyAiService } from './middlewares/certify-ai-service';
  13. const logger = loggerFactory('growi:routes:apiv3:openai:thread');
  14. type CreateThreadReq = Request<undefined, ApiV3Response, { threadId?: string }> & { user: IUserHasId };
  15. type CreateThreadFactory = (crowi: Crowi) => RequestHandler[];
  16. export const createThreadHandlersFactory: CreateThreadFactory = (crowi) => {
  17. const loginRequiredStrictly = require('~/server/middlewares/login-required')(crowi);
  18. const validator: ValidationChain[] = [
  19. body('threadId').optional().isString().withMessage('threadId must be string'),
  20. ];
  21. return [
  22. accessTokenParser, loginRequiredStrictly, certifyAiService, validator, apiV3FormValidator,
  23. async(req: CreateThreadReq, res: ApiV3Response) => {
  24. try {
  25. const openaiService = getOpenaiService();
  26. const filterdThreadId = req.body.threadId != null ? filterXSS(req.body.threadId) : undefined;
  27. const vectorStore = await openaiService?.getOrCreateVectorStoreForPublicScope();
  28. const thread = await openaiService?.getOrCreateThread(req.user._id, vectorStore?.vectorStoreId, filterdThreadId);
  29. return res.apiv3({ thread });
  30. }
  31. catch (err) {
  32. logger.error(err);
  33. return res.apiv3Err(err);
  34. }
  35. },
  36. ];
  37. };