Просмотр исходного кода

Merge pull request #10345 from growilabs/fix/156800-csrf-protection-origin

fix: CSRF protection by origin comparison
mergify[bot] 5 месяцев назад
Родитель
Сommit
77e126de3e

+ 6 - 3
apps/app/src/server/crowi/express-init.js

@@ -2,9 +2,11 @@ import { themesRootPath as presetThemesRootPath } from '@growi/preset-themes';
 import csrf from 'csurf';
 import csrf from 'csurf';
 import qs from 'qs';
 import qs from 'qs';
 
 
-import { PLUGIN_EXPRESS_STATIC_DIR, PLUGIN_STORING_PATH } from '~/features/growi-plugin/server/consts';
+
+import { PLUGIN_EXPRESS_STATIC_DIR, PLUGIN_STORING_PATH } from '../../features/growi-plugin/server/consts';
+import loggerFactory from '../../utils/logger';
 import { resolveFromRoot } from '~/server/util/project-dir-utils';
 import { resolveFromRoot } from '~/server/util/project-dir-utils';
-import loggerFactory from '~/utils/logger';
+import CertifyOrigin from '../middlewares/certify-origin';
 
 
 import registerSafeRedirectFactory from '../middlewares/safe-redirect';
 import registerSafeRedirectFactory from '../middlewares/safe-redirect';
 
 
@@ -26,7 +28,6 @@ module.exports = function(crowi, app) {
   const registerSafeRedirect = registerSafeRedirectFactory();
   const registerSafeRedirect = registerSafeRedirectFactory();
   const injectCurrentuserToLocalvars = require('../middlewares/inject-currentuser-to-localvars')();
   const injectCurrentuserToLocalvars = require('../middlewares/inject-currentuser-to-localvars')();
   const autoReconnectToS2sMsgServer = require('../middlewares/auto-reconnect-to-s2s-msg-server')(crowi);
   const autoReconnectToS2sMsgServer = require('../middlewares/auto-reconnect-to-s2s-msg-server')(crowi);
-
   const avoidSessionRoutes = require('../routes/avoid-session-routes');
   const avoidSessionRoutes = require('../routes/avoid-session-routes');
 
 
   const env = crowi.node_env;
   const env = crowi.node_env;
@@ -123,6 +124,8 @@ module.exports = function(crowi, app) {
   // default methods + PUT. See: https://expressjs.com/en/resources/middleware/csurf.html#ignoremethods
   // default methods + PUT. See: https://expressjs.com/en/resources/middleware/csurf.html#ignoremethods
   app.use(csrf({ ignoreMethods: ['GET', 'HEAD', 'OPTIONS', 'PUT', 'POST', 'DELETE'], cookie: false }));
   app.use(csrf({ ignoreMethods: ['GET', 'HEAD', 'OPTIONS', 'PUT', 'POST', 'DELETE'], cookie: false }));
 
 
+  app.use('/_api', CertifyOrigin);
+
   // passport
   // passport
   logger.debug('initialize Passport');
   logger.debug('initialize Passport');
   app.use(passport.initialize());
   app.use(passport.initialize());

+ 42 - 0
apps/app/src/server/middlewares/certify-origin.ts

@@ -0,0 +1,42 @@
+import { ErrorV3 } from '@growi/core/dist/models';
+import type { NextFunction, Response } from 'express';
+
+import loggerFactory from '../../utils/logger';
+import { configManager } from '../service/config-manager';
+import isSimpleRequest from '../util/is-simple-request';
+
+import type { AccessTokenParserReq } from './access-token-parser/interfaces';
+
+
+const logger = loggerFactory('growi:middleware:certify-origin');
+
+type Apiv3ErrFunction = (error: ErrorV3) => void;
+
+const certifyOrigin = (req: AccessTokenParserReq, res: Response & { apiv3Err: Apiv3ErrFunction }, next: NextFunction): void => {
+
+  const appSiteUrl = configManager.getConfig('app:siteUrl');
+  const configuredOrigin = appSiteUrl ? new URL(appSiteUrl).origin : null;
+  const requestOrigin = req.headers.origin;
+  const runtimeOrigin = `${req.protocol}://${req.get('host')}`;
+
+  const isSameOriginReq = requestOrigin == null
+  || requestOrigin === configuredOrigin
+  || requestOrigin === runtimeOrigin;
+
+  const accessToken = req.query.access_token ?? req.body.access_token;
+
+  if (!isSameOriginReq && req.headers.origin != null && isSimpleRequest(req)) {
+    const message = 'Invalid request (origin check failed but simple request)';
+    logger.error(message);
+    return res.apiv3Err(new ErrorV3(message));
+  }
+
+  if (!isSameOriginReq && accessToken == null && !isSimpleRequest(req)) {
+    const message = 'Invalid request (origin check failed and no access token)';
+    logger.error(message);
+    return res.apiv3Err(new ErrorV3(message));
+  }
+
+  next();
+};
+export default certifyOrigin;

+ 2 - 1
apps/app/src/server/routes/index.js

@@ -49,6 +49,7 @@ module.exports = function(crowi, app) {
   const tag = require('./tag')(crowi, app);
   const tag = require('./tag')(crowi, app);
   const search = require('./search')(crowi, app);
   const search = require('./search')(crowi, app);
   const ogp = require('./ogp')(crowi);
   const ogp = require('./ogp')(crowi);
+  const { createApiRouter } = require('~/server/util/createApiRouter');
 
 
   const next = nextFactory(crowi);
   const next = nextFactory(crowi);
 
 
@@ -121,7 +122,7 @@ module.exports = function(crowi, app) {
   // API v3
   // API v3
   app.use('/_api/v3', unavailableWhenMaintenanceModeForApi, apiV3Router);
   app.use('/_api/v3', unavailableWhenMaintenanceModeForApi, apiV3Router);
 
 
-  const apiV1Router = express.Router();
+  const apiV1Router = createApiRouter();
 
 
   apiV1Router.get('/search'              , accessTokenParser([SCOPE.READ.FEATURES.PAGE], { acceptLegacy: true }) , loginRequired , search.api.search);
   apiV1Router.get('/search'              , accessTokenParser([SCOPE.READ.FEATURES.PAGE], { acceptLegacy: true }) , loginRequired , search.api.search);
 
 

+ 12 - 0
apps/app/src/server/util/createApiRouter.ts

@@ -0,0 +1,12 @@
+import express, { type Router } from 'express';
+import CertifyOrigin from '~/server/middlewares/certify-origin';
+
+function createApiRouter(): Router {
+  const router = express.Router();
+  router.use(CertifyOrigin);
+  return router;
+}
+
+export {
+  createApiRouter,
+};

+ 200 - 0
apps/app/src/server/util/is-simple-request.spec.ts

@@ -0,0 +1,200 @@
+import type { Request } from 'express';
+import { mock } from 'vitest-mock-extended';
+
+import isSimpleRequest from './is-simple-request';
+
+describe('isSimpleRequest', () => {
+
+
+  // method
+  describe('When request method is checked', () => {
+
+    // allow
+    describe('When allowed method is given', () => {
+      const allowedMethods = ['GET', 'HEAD', 'POST'];
+      it.each(allowedMethods)('returns true for %s method', (method) => {
+        const reqMock = mock<Request>();
+        reqMock.method = method;
+        reqMock.headers = { 'content-type': 'text/plain' };
+        expect(isSimpleRequest(reqMock)).toBe(true);
+      });
+    });
+
+    // disallow
+    describe('When disallowed method is given', () => {
+      const disallowedMethods = ['PUT', 'DELETE', 'PATCH', 'OPTIONS', 'TRACE'];
+      it.each(disallowedMethods)('returns false for %s method', (method) => {
+        const reqMock = mock<Request>();
+        reqMock.method = method;
+        reqMock.headers = {};
+        expect(isSimpleRequest(reqMock)).toBe(false);
+      });
+    });
+
+  });
+
+
+  // headers
+  describe('When request headers are checked', () => {
+
+    // allow(Other than content-type)
+    describe('When only safe headers are given', () => {
+      const safeHeaders = [
+        'accept',
+        'accept-language',
+        'content-language',
+        'range',
+        'referer',
+        'dpr',
+        'downlink',
+        'save-data',
+        'viewport-width',
+        'width',
+      ];
+      it.each(safeHeaders)('returns true for safe header: %s', (headerName) => {
+        const reqMock = mock<Request>();
+        reqMock.method = 'POST';
+        reqMock.headers = {
+          [headerName]: 'test-value',
+        };
+        expect(isSimpleRequest(reqMock)).toBe(true);
+      });
+      // content-type
+      it('returns true for valid content-type values', () => {
+        const validContentTypes = [
+          'application/x-www-form-urlencoded',
+          'multipart/form-data',
+          'text/plain',
+        ];
+        validContentTypes.forEach((contentType) => {
+          const reqMock = mock<Request>();
+          reqMock.method = 'POST';
+          reqMock.headers = { 'content-type': contentType };
+          expect(isSimpleRequest(reqMock)).toBe(true);
+        });
+      });
+      // combination
+      it('returns true for combination of safe headers', () => {
+        const reqMock = mock<Request>();
+        reqMock.method = 'POST';
+        reqMock.headers = {
+          Accept: 'application/json',
+          'content-Type': 'text/plain',
+          'Accept-Language': 'en-US',
+        };
+        expect(isSimpleRequest(reqMock)).toBe(true);
+      });
+    });
+
+    // disallow
+    describe('When unsafe headers are given', () => {
+      const unsafeHeaders = [
+        'X-Custom-Header',
+        'Authorization',
+        'X-Requested-With',
+        'X-CSRF-Token',
+      ];
+      it.each(unsafeHeaders)('returns false for unsafe header: %s', (headerName) => {
+        const reqMock = mock<Request>({
+          method: 'POST',
+          headers: { [headerName]: 'test-value' },
+        });
+        expect(isSimpleRequest(reqMock)).toBe(false);
+      });
+      // combination
+      it('returns false when safe and unsafe headers are mixed', () => {
+        const reqMock = mock<Request>();
+        reqMock.method = 'POST';
+        reqMock.headers = {
+          Accept: 'application/json', // Safe
+          'X-Custom-Header': 'custom-value', // Unsafe
+        };
+        expect(isSimpleRequest(reqMock)).toBe(false);
+      });
+    });
+
+  });
+
+
+  // content-type
+  describe('When content-type is checked', () => {
+
+    // allow
+    describe('When a safe content-type is given', () => {
+      const safeContentTypes = [
+        'application/x-www-form-urlencoded',
+        'multipart/form-data',
+        'text/plain',
+        // parameters
+        'application/x-www-form-urlencoded; charset=UTF-8',
+        'multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW',
+        'text/plain; charset=iso-8859-1',
+      ];
+      it.each(safeContentTypes)('returns true for %s', (contentType) => {
+        const reqMock = mock<Request>();
+        reqMock.method = 'POST';
+        reqMock.headers = {
+          'content-type': contentType,
+        };
+        expect(isSimpleRequest(reqMock)).toBe(true);
+      });
+    });
+    // absent
+    describe('When content-type is absent', () => {
+      it('returns true when no content-type header is set', () => {
+        const reqMock = mock<Request>();
+        reqMock.method = 'POST';
+        reqMock.headers = {};
+        expect(isSimpleRequest(reqMock)).toBe(true);
+      });
+    });
+
+    // disallow
+    describe('When disallowed content-type is given', () => {
+      const disallowedContentTypes = [
+        'application/json',
+        'application/xml',
+        'text/html',
+        'application/octet-stream',
+      ];
+      it.each(disallowedContentTypes)('returns false for %s', (contentType) => {
+        const reqMock = mock<Request>();
+        reqMock.method = 'POST';
+        reqMock.headers = { 'content-type': contentType };
+        expect(isSimpleRequest(reqMock)).toBe(false);
+      });
+    });
+
+  });
+
+  // integration
+  describe('When multiple conditions are checked', () => {
+
+    describe('When all conditions are met', () => {
+      it('returns true', () => {
+        const reqMock = mock<Request>();
+        reqMock.method = 'POST';
+        reqMock.headers = { 'content-type': 'application/x-www-form-urlencoded' };
+        expect(isSimpleRequest(reqMock)).toBe(true);
+      });
+    });
+
+    describe('When method is disallowed but headers are safe', () => {
+      it('returns false', () => {
+        const reqMock = mock<Request>();
+        reqMock.method = 'PUT';
+        reqMock.headers = { 'content-type': 'text/plain' };
+        expect(isSimpleRequest(reqMock)).toBe(false);
+      });
+    });
+
+    describe('When method is allowed but headers are non-safe', () => {
+      it('returns false', () => {
+        const reqMock = mock<Request>();
+        reqMock.method = 'POST';
+        reqMock.headers = { 'X-Custom-Header': 'custom-value' };
+        expect(isSimpleRequest(reqMock)).toBe(false);
+      });
+    });
+  });
+});

+ 65 - 0
apps/app/src/server/util/is-simple-request.ts

@@ -0,0 +1,65 @@
+import type { Request } from 'express';
+
+import type { AccessTokenParserReq } from '~/server/middlewares/access-token-parser/interfaces';
+
+const allowedMethods = ['GET', 'HEAD', 'POST'] as const;
+type AllowedMethod = typeof allowedMethods[number];
+function isAllowedMethod(method: string): method is AllowedMethod {
+  return allowedMethods.includes(method as AllowedMethod);
+}
+
+const safeRequestHeaders = [
+  'accept',
+  'accept-language',
+  'content-language',
+  'content-type',
+  'range',
+  'referer',
+  'dpr',
+  'downlink',
+  'save-data',
+  'viewport-width',
+  'width',
+] as const;
+type SafeRequestHeader = typeof safeRequestHeaders[number];
+
+function isSafeRequestHeader(header: string): header is SafeRequestHeader {
+  return safeRequestHeaders.includes(header.toLowerCase() as SafeRequestHeader);
+}
+
+const allowedContentTypes = [
+  'application/x-www-form-urlencoded',
+  'multipart/form-data',
+  'text/plain',
+] as const;
+type AllowedContentType = typeof allowedContentTypes[number];
+
+function isAllowedContentType(contentType: string): contentType is AllowedContentType {
+  return allowedContentTypes.some(allowed => contentType.toLowerCase().startsWith(allowed));
+}
+
+const isSimpleRequest = (req: Request | AccessTokenParserReq): boolean => {
+  // 1. Check if the request method is allowed
+  if (!isAllowedMethod(req.method)) {
+    return false;
+  }
+
+  // 2. Check if the request headers are safe
+  const nonSafeHeaders = Object.keys(req.headers).filter((header) => {
+    return !isSafeRequestHeader(header);
+  });
+  if (nonSafeHeaders.length > 0) {
+    return false;
+  }
+
+  // 3. Content-Type is
+  const contentType = req.headers['content-type'];
+  if (contentType != null && !isAllowedContentType(contentType)) {
+    return false;
+  }
+
+  // Return true if all conditions are met
+  return true;
+};
+
+export default isSimpleRequest;