Shun Miyazawa 1 год назад
Родитель
Сommit
a91a192a2e

+ 5 - 8
apps/app/src/features/rate-limiter/middleware/consume-points.integ.ts

@@ -1,6 +1,6 @@
 import { faker } from '@faker-js/faker';
 
-const testRateLimitErrorWhenExceedingMaxRequests = async(method: string, endpoint: string, key: string, maxRequests: number): Promise<void> => {
+const testRateLimitErrorWhenExceedingMaxRequests = async(method: string, key: string, maxRequests: number): Promise<void> => {
   // dynamic import is used because rateLimiterMongo needs to be initialized after connecting to DB
   // Issue: https://github.com/animir/node-rate-limiter-flexible/issues/216
   const { consumePoints } = await import('./consume-points');
@@ -9,7 +9,7 @@ const testRateLimitErrorWhenExceedingMaxRequests = async(method: string, endpoin
     for (let i = 1; i <= maxRequests + 1; i++) {
       count += 1;
       // eslint-disable-next-line no-await-in-loop
-      const res = await consumePoints(method, endpoint, key, { method, maxRequests });
+      const res = await consumePoints(method, key, { method, maxRequests });
       if (count === maxRequests) {
         // Expect consumedPoints to be equal to maxRequest when maxRequest is reached
         expect(res?.consumedPoints).toBe(maxRequests);
@@ -34,30 +34,27 @@ describe('consume-points.ts', async() => {
   it('Should trigger a rate limit error when maxRequest is exceeded (maxRequest: 1)', async() => {
     // setup
     const method = 'GET';
-    const endpoint = '/_api/v3/test-1';
     const key = 'test-key-1';
     const maxRequests = 1;
 
-    await testRateLimitErrorWhenExceedingMaxRequests(method, endpoint, key, maxRequests);
+    await testRateLimitErrorWhenExceedingMaxRequests(method, key, maxRequests);
   });
 
   it('Should trigger a rate limit error when maxRequest is exceeded (maxRequest: 500)', async() => {
     // setup
     const method = 'GET';
-    const endpoint = '/_api/v3/test-2';
     const key = 'test-key-2';
     const maxRequests = 500;
 
-    await testRateLimitErrorWhenExceedingMaxRequests(method, endpoint, key, maxRequests);
+    await testRateLimitErrorWhenExceedingMaxRequests(method, key, maxRequests);
   });
 
   it('Should trigger a rate limit error when maxRequest is exceeded (maxRequest: {random integer between 1 and 1000})', async() => {
     // setup
     const method = 'GET';
-    const endpoint = '/_api/v3/test-3';
     const key = 'test-key-3';
     const maxRequests = faker.number.int({ min: 1, max: 1000 });
 
-    await testRateLimitErrorWhenExceedingMaxRequests(method, endpoint, key, maxRequests);
+    await testRateLimitErrorWhenExceedingMaxRequests(method, key, maxRequests);
   });
 });

+ 2 - 2
apps/app/src/features/rate-limiter/middleware/consume-points.ts

@@ -5,7 +5,7 @@ import { DEFAULT_MAX_REQUESTS, type IApiRateLimitConfig } from '../config';
 import { rateLimiterFactory } from './rate-limiter-factory';
 
 export const consumePoints = async(
-    method: string, endpoint: string, key: string | null, customizedConfig?: IApiRateLimitConfig, maxRequestsMultiplier?: number,
+    method: string, key: string | null, customizedConfig?: IApiRateLimitConfig, maxRequestsMultiplier?: number,
 ): Promise<RateLimiterRes | undefined> => {
   if (key == null) {
     return;
@@ -23,7 +23,7 @@ export const consumePoints = async(
     maxRequests *= maxRequestsMultiplier;
   }
 
-  const rateLimiter = rateLimiterFactory.getOrCreateRateLimiter(endpoint, maxRequests);
+  const rateLimiter = rateLimiterFactory.getOrCreateRateLimiter(key, maxRequests);
 
   const pointsToConsume = 1;
   const rateLimiterRes = await rateLimiter.consume(key, pointsToConsume);

+ 6 - 6
apps/app/src/features/rate-limiter/middleware/factory.ts

@@ -34,9 +34,9 @@ const valuesWithRegExp = Object.values(configWithRegExp);
  * @returns
  */
 const consumePointsByUser = async(
-    method: string, endpoint: string, key: string | null, customizedConfig?: IApiRateLimitConfig,
+    method: string, key: string | null, customizedConfig?: IApiRateLimitConfig,
 ): Promise<RateLimiterRes | undefined> => {
-  return consumePoints(method, endpoint, key, customizedConfig);
+  return consumePoints(method, key, customizedConfig);
 };
 
 /**
@@ -47,10 +47,10 @@ const consumePointsByUser = async(
  * @returns
  */
 const consumePointsByIp = async(
-    method: string, endpoint: string, key: string | null, customizedConfig?: IApiRateLimitConfig,
+    method: string, key: string | null, customizedConfig?: IApiRateLimitConfig,
 ): Promise<RateLimiterRes | undefined> => {
   const maxRequestsMultiplier = customizedConfig?.usersPerIpProspection ?? DEFAULT_USERS_PER_IP_PROSPECTION;
-  return consumePoints(method, endpoint, key, customizedConfig, maxRequestsMultiplier);
+  return consumePoints(method, key, customizedConfig, maxRequestsMultiplier);
 };
 
 
@@ -83,7 +83,7 @@ export const middlewareFactory = (): Handler => {
     // check for the current user
     if (req.user != null) {
       try {
-        await consumePointsByUser(req.method, endpoint, keyForUser, customizedConfig);
+        await consumePointsByUser(req.method, keyForUser, customizedConfig);
       }
       catch {
         logger.error(`${req.user._id}: too many request at ${endpoint}`);
@@ -93,7 +93,7 @@ export const middlewareFactory = (): Handler => {
 
     // check for ip
     try {
-      await consumePointsByIp(req.method, endpoint, keyForIp, customizedConfig);
+      await consumePointsByIp(req.method, keyForIp, customizedConfig);
     }
     catch {
       logger.error(`${req.ip}: too many request at ${endpoint}`);

+ 1 - 7
apps/app/src/features/rate-limiter/middleware/rate-limiter-factory.ts

@@ -7,13 +7,7 @@ class RateLimiterFactory {
 
   private rateLimiters: Map<string, RateLimiterMongo> = new Map();
 
-  private generateKey(endpoint: string): string {
-    return `rate_limiter_${endpoint}`;
-  }
-
-  getOrCreateRateLimiter(endpoint: string, maxRequests: number): RateLimiterMongo {
-    const key = this.generateKey(endpoint);
-
+  getOrCreateRateLimiter(key: string, maxRequests: number): RateLimiterMongo {
     const cachedRateLimiter = this.rateLimiters.get(key);
     if (cachedRateLimiter != null) {
       return cachedRateLimiter;