yusa-a 8 месяцев назад
Родитель
Сommit
391469ea61

+ 2 - 3
apps/app/src/server/crowi/express-init.js

@@ -3,7 +3,7 @@ import csrf from 'csurf';
 import qs from 'qs';
 
 import { PLUGIN_EXPRESS_STATIC_DIR, PLUGIN_STORING_PATH } from '~/features/growi-plugin/server/consts';
-import registerCertifyOrigin from '~/server/middlewares/certify-origin';
+import CertifyOrigin from '~/server/middlewares/certify-origin';
 import loggerFactory from '~/utils/logger';
 import { resolveFromRoot } from '~/utils/project-dir-utils';
 
@@ -27,7 +27,6 @@ module.exports = function(crowi, app) {
   const registerSafeRedirect = registerSafeRedirectFactory();
   const injectCurrentuserToLocalvars = require('../middlewares/inject-currentuser-to-localvars')();
   const autoReconnectToS2sMsgServer = require('../middlewares/auto-reconnect-to-s2s-msg-server')(crowi);
-  const certifyOrigin = registerCertifyOrigin(crowi);
   const avoidSessionRoutes = require('../routes/avoid-session-routes');
 
   const env = crowi.node_env;
@@ -124,7 +123,7 @@ module.exports = function(crowi, app) {
   // default methods + PUT. See: https://expressjs.com/en/resources/middleware/csurf.html#ignoremethods
   app.use(csrf({ ignoreMethods: ['GET', 'HEAD', 'OPTIONS', 'PUT', 'POST', 'DELETE'], cookie: false }));
 
-  app.use(certifyOrigin);
+  app.use(CertifyOrigin);//
 
   // passport
   logger.debug('initialize Passport');

+ 18 - 22
apps/app/src/server/middlewares/certify-origin.ts

@@ -10,29 +10,25 @@ const logger = loggerFactory('growi:middleware:certify-origin');
 
 type Apiv3ErrFunction = (error: ErrorV3) => void;
 
-const certifyOrigin = (): ((req: AccessTokenParserReq, res: Response & { apiv3Err: Apiv3ErrFunction }, next: NextFunction) => void) => {
+const certifyOrigin = (req: AccessTokenParserReq, res: Response & { apiv3Err: Apiv3ErrFunction }, next: NextFunction): void => {
 
   const appSiteUrl = configManager.getConfig('app:siteUrl');
-  return (req: AccessTokenParserReq, res: Response & { apiv3Err }, next: NextFunction): void => {
-
-    const isSameOriginReq = req.headers.origin == null || req.headers.origin === appSiteUrl;
-    req.isSameOriginReq = isSameOriginReq;
-    const accessToken = req.query.access_token ?? req.body.access_token;
-    req.isSimpleRequest = isSimpleRequest(req);
-
-    if (!isSameOriginReq && req.headers.origin != null && req.isSimpleRequest) {
-      const message = 'Invalid request (origin check failed but simple request)';
-      logger.error(message);
-      return res.apiv3Err(new ErrorV3(message));
-    }
-
-    if (!isSameOriginReq && accessToken == null && !req.isSimpleRequest) {
-      const message = 'Invalid request (origin check failed and no access token)';
-      logger.error(message);
-      return res.apiv3Err(new ErrorV3(message));
-    }
-
-    next();
-  };
+
+  const isSameOriginReq = req.headers.origin == null || req.headers.origin === appSiteUrl;
+  const accessToken = req.query.access_token ?? req.body.access_token;
+
+  if (!isSameOriginReq && req.headers.origin != null && isSimpleRequest(req)) {
+    const message = 'Invalid request (origin check failed but simple request)';
+    logger.error(message);
+    return res.apiv3Err(new ErrorV3(message));
+  }
+
+  if (!isSameOriginReq && accessToken == null && !isSimpleRequest(req)) {
+    const message = 'Invalid request (origin check failed and no access token)';
+    logger.error(message);
+    return res.apiv3Err(new ErrorV3(message));
+  }
+
+  next();
 };
 export default certifyOrigin;

+ 13 - 7
apps/app/src/server/util/is-simple-request.ts

@@ -2,14 +2,12 @@ import type { Request } from 'express';
 
 import type { AccessTokenParserReq } from '~/server/middlewares/access-token-parser/interfaces';
 
-// 1. Check if the request method is allowed
 const allowedMethods = ['GET', 'HEAD', 'POST'] as const;
 type AllowedMethod = typeof allowedMethods[number];
 function isAllowedMethod(method: string): method is AllowedMethod {
   return allowedMethods.includes(method as AllowedMethod);
 }
 
-// 2. Check if the request headers are safe
 const safeRequestHeaders = [
   'accept',
   'accept-language',
@@ -25,7 +23,10 @@ const safeRequestHeaders = [
 ] as const;
 type SafeRequestHeader = typeof safeRequestHeaders[number];
 
-// 3. Content-Type is
+function isSafeRequestHeader(header: string): header is SafeRequestHeader {
+  return safeRequestHeaders.includes(header.toLowerCase() as SafeRequestHeader);
+}
+
 const allowedContentTypes = [
   'application/x-www-form-urlencoded',
   'multipart/form-data',
@@ -33,22 +34,27 @@ const allowedContentTypes = [
 ] as const;
 type AllowedContentType = typeof allowedContentTypes[number];
 
+function isAllowedContentType(contentType: string): contentType is AllowedContentType {
+  return allowedContentTypes.some(allowed => contentType.toLowerCase().startsWith(allowed));
+}
+
 const isSimpleRequest = (req: Request | AccessTokenParserReq): boolean => {
+  // 1. Check if the request method is allowed
   if (!isAllowedMethod(req.method)) {
     return false;
   }
 
+  // 2. Check if the request headers are safe
   const nonSafeHeaders = Object.keys(req.headers).filter((header) => {
-    const headerLower = header.toLowerCase();
-    return !safeRequestHeaders.includes(headerLower as SafeRequestHeader);
+    return !isSafeRequestHeader(header);
   });
-
   if (nonSafeHeaders.length > 0) {
     return false;
   }
 
+  // 3. Content-Type is
   const contentType = req.headers['content-type'];
-  if (contentType != null && !allowedContentTypes.includes(contentType.toLowerCase() as AllowedContentType)) {
+  if (contentType != null && !isAllowedContentType(contentType)) {
     return false;
   }