فهرست منبع

feat: implement socket guarding mechanism to prevent premature closure during async operations

Yuki Takei 1 هفته پیش
والد
کامیت
6b9285ce65

+ 159 - 0
apps/app/src/server/service/yjs/guard-socket.spec.ts

@@ -0,0 +1,159 @@
+import http from 'node:http';
+import WebSocket, { WebSocketServer } from 'ws';
+import { docs, setPersistence, setupWSConnection } from 'y-websocket/bin/utils';
+
+import { guardSocket } from './guard-socket';
+
+/**
+ * Creates a test server where:
+ * 1. The Yjs upgrade handler guards the socket and awaits before completing
+ * 2. A hostile handler (simulating Next.js) calls socket.end() for /yjs/ paths
+ */
+const createServerWithHostileHandler = (): {
+  server: http.Server;
+  wss: WebSocketServer;
+} => {
+  const server = http.createServer();
+  const wss = new WebSocketServer({ noServer: true });
+
+  // Yjs handler (registered first — same order as production)
+  server.on('upgrade', async (request, socket, head) => {
+    const url = request.url ?? '';
+    if (!url.startsWith('/yjs/')) return;
+
+    const pageId = url.slice('/yjs/'.length).split('?')[0];
+
+    const guard = guardSocket(socket);
+
+    try {
+      // Simulate async auth delay
+      await new Promise((resolve) => setTimeout(resolve, 10));
+
+      guard.restore();
+
+      wss.handleUpgrade(request, socket, head, (ws) => {
+        wss.emit('connection', ws, request);
+        setupWSConnection(ws, request, { docName: pageId });
+      });
+    } catch {
+      guard.restore();
+      socket.destroy();
+    }
+  });
+
+  // Hostile handler (registered second — simulates Next.js upgradeHandler)
+  server.on('upgrade', (_request, socket) => {
+    socket.end();
+  });
+
+  return { server, wss };
+};
+
+const connectClient = (port: number, pageId: string): Promise<WebSocket> => {
+  return new Promise((resolve, reject) => {
+    const ws = new WebSocket(`ws://127.0.0.1:${port}/yjs/${pageId}`);
+    ws.binaryType = 'arraybuffer';
+    ws.on('open', () => resolve(ws));
+    ws.on('error', reject);
+  });
+};
+
+describe('guardSocket — protection against hostile upgrade handlers', () => {
+  let server: http.Server;
+  let wss: WebSocketServer;
+  let port: number;
+
+  beforeAll(async () => {
+    setPersistence(null);
+
+    const testServer = createServerWithHostileHandler();
+    server = testServer.server;
+    wss = testServer.wss;
+
+    await new Promise<void>((resolve) => {
+      server.listen(0, '127.0.0.1', () => {
+        const addr = server.address();
+        if (addr && typeof addr === 'object') {
+          port = addr.port;
+        }
+        resolve();
+      });
+    });
+  });
+
+  afterAll(async () => {
+    for (const [name, doc] of docs) {
+      doc.destroy();
+      docs.delete(name);
+    }
+
+    await new Promise<void>((resolve) => {
+      wss.close(() => {
+        server.close(() => resolve());
+      });
+    });
+  });
+
+  afterEach(() => {
+    for (const [name, doc] of docs) {
+      doc.destroy();
+      docs.delete(name);
+    }
+  });
+
+  it('should establish WebSocket connection even when a hostile handler calls socket.end()', async () => {
+    const pageId = 'guard-test-001';
+
+    const ws = await connectClient(port, pageId);
+
+    await new Promise((resolve) => setTimeout(resolve, 50));
+
+    const serverDoc = docs.get(pageId);
+    expect(serverDoc).toBeDefined();
+    assert(serverDoc !== undefined);
+    expect(serverDoc.conns.size).toBe(1);
+
+    ws.close();
+  });
+
+  it('should handle multiple concurrent connections with hostile handler', async () => {
+    const pageId = 'guard-test-002';
+
+    const connections = await Promise.all([
+      connectClient(port, pageId),
+      connectClient(port, pageId),
+    ]);
+
+    await new Promise((resolve) => setTimeout(resolve, 50));
+
+    const serverDoc = docs.get(pageId);
+    expect(serverDoc).toBeDefined();
+    assert(serverDoc !== undefined);
+    expect(serverDoc.conns.size).toBe(2);
+
+    for (const ws of connections) {
+      ws.close();
+    }
+  });
+
+  it('should properly restore socket methods after guard', async () => {
+    const pageId = 'guard-test-003';
+
+    const ws = await connectClient(port, pageId);
+
+    await new Promise((resolve) => setTimeout(resolve, 50));
+
+    // Connection succeeds, meaning socket.end/destroy were properly
+    // guarded during async auth and restored before wss.handleUpgrade
+    expect(ws.readyState).toBe(WebSocket.OPEN);
+
+    ws.close();
+    await new Promise((resolve) => setTimeout(resolve, 50));
+
+    // After close, the server doc should have removed the connection
+    const serverDoc = docs.get(pageId);
+    if (serverDoc) {
+      expect(serverDoc.conns.size).toBe(0);
+    }
+  });
+});

+ 30 - 0
apps/app/src/server/service/yjs/guard-socket.ts

@@ -0,0 +1,30 @@
+import type { Duplex } from 'node:stream';
+
+type SocketGuard = {
+  restore: () => void;
+};
+
+/**
+ * Temporarily replaces socket.end() and socket.destroy() with no-ops.
+ *
+ * This prevents other synchronous `upgrade` event listeners (e.g. Next.js's
+ * NextCustomServer.upgradeHandler) from closing the socket while an async
+ * handler is awaiting authentication.
+ *
+ * Call `restore()` on the returned object to reinstate the original methods
+ * before performing the actual WebSocket handshake or cleanup.
+ */
+export const guardSocket = (socket: Duplex): SocketGuard => {
+  const origEnd = socket.end.bind(socket);
+  const origDestroy = socket.destroy.bind(socket);
+
+  socket.end = () => socket;
+  socket.destroy = () => socket;
+
+  return {
+    restore: () => {
+      socket.end = origEnd;
+      socket.destroy = origDestroy;
+    },
+  };
+};

+ 2 - 2
apps/app/src/server/service/yjs/upgrade-handler.ts

@@ -49,10 +49,10 @@ const runMiddleware = (
  * Extracts pageId from upgrade request URL.
  * Extracts pageId from upgrade request URL.
  * Expected format: /yjs/{pageId}
  * Expected format: /yjs/{pageId}
  */
  */
+const pageIdPattern = new RegExp(`^${YJS_WEBSOCKET_BASE_PATH}/([a-f0-9]{24})`);
 const extractPageId = (url: string | undefined): string | null => {
 const extractPageId = (url: string | undefined): string | null => {
   if (url == null) return null;
   if (url == null) return null;
-  const pattern = new RegExp(`^${YJS_WEBSOCKET_BASE_PATH}/([a-f0-9]{24})`);
-  const match = url.match(pattern);
+  const match = url.match(pageIdPattern);
   return match?.[1] ?? null;
   return match?.[1] ?? null;
 };
 };
 
 

+ 32 - 9
apps/app/src/server/service/yjs/yjs.ts

@@ -15,6 +15,7 @@ import { normalizeLatestRevisionIfBroken } from '../revision/normalize-latest-re
 import { createIndexes } from './create-indexes';
 import { createIndexes } from './create-indexes';
 import { createMongoDBPersistence } from './create-mongodb-persistence';
 import { createMongoDBPersistence } from './create-mongodb-persistence';
 import { MongodbPersistence } from './extended/mongodb-persistence';
 import { MongodbPersistence } from './extended/mongodb-persistence';
+import { guardSocket } from './guard-socket';
 import { syncYDoc } from './sync-ydoc';
 import { syncYDoc } from './sync-ydoc';
 import { createUpgradeHandler } from './upgrade-handler';
 import { createUpgradeHandler } from './upgrade-handler';
 
 
@@ -75,16 +76,38 @@ class YjsService implements IYjsService {
         return;
         return;
       }
       }
 
 
-      const result = await handleUpgrade(request, socket, head);
-
-      if (!result.authorized) {
-        return;
+      // Guard the socket against being closed by other upgrade handlers
+      // (e.g. Next.js's NextCustomServer.upgradeHandler) that run synchronously
+      // after this async handler yields at the first await.
+      const guard = guardSocket(socket);
+
+      try {
+        const result = await handleUpgrade(request, socket, head);
+
+        // Restore original socket methods now that all synchronous
+        // upgrade handlers have finished
+        guard.restore();
+
+        if (!result.authorized) {
+          // rejectUpgrade already wrote the HTTP error response but
+          // socket.destroy() was a no-op during the guard; clean up now
+          socket.destroy();
+          return;
+        }
+
+        wss.handleUpgrade(result.request, socket, head, (ws) => {
+          wss.emit('connection', ws, result.request);
+          setupWSConnection(ws, result.request, { docName: result.pageId });
+        });
+      } catch (err) {
+        guard.restore();
+
+        logger.error('Yjs upgrade handler failed unexpectedly', { url, err });
+        if (socket.writable) {
+          socket.write('HTTP/1.1 500 Internal Server Error\r\n\r\n');
+        }
+        socket.destroy();
       }
       }
-
-      wss.handleUpgrade(result.request, socket, head, (ws) => {
-        wss.emit('connection', ws, result.request);
-        setupWSConnection(ws, result.request, { docName: result.pageId });
-      });
     });
     });
 
 
     logger.info('YjsService initialized with y-websocket');
     logger.info('YjsService initialized with y-websocket');

+ 13 - 21
packages/editor/src/client/stores/use-collaborative-editor-mode.ts

@@ -67,18 +67,21 @@ export const useCollaborativeEditorMode = (
     const { awareness } = _provider;
     const { awareness } = _provider;
     awareness.setLocalStateField('editors', userLocalState);
     awareness.setLocalStateField('editors', userLocalState);
 
 
-    const providerSyncHandler = (isSync: boolean) => {
-      if (isSync && onEditorsUpdated != null) {
-        const clientList: EditingClient[] = Array.from(
-          awareness.getStates().values(),
-          (value) => value.editors,
-        );
-        if (Array.isArray(clientList)) {
-          onEditorsUpdated(clientList);
-        }
+    const emitEditorList = () => {
+      if (onEditorsUpdated == null) return;
+      const clientList: EditingClient[] = Array.from(
+        awareness.getStates().values(),
+        (value) => value.editors,
+      );
+      if (Array.isArray(clientList)) {
+        onEditorsUpdated(clientList);
       }
       }
     };
     };
 
 
+    const providerSyncHandler = (isSync: boolean) => {
+      if (isSync) emitEditorList();
+    };
+
     _provider.on('sync', providerSyncHandler);
     _provider.on('sync', providerSyncHandler);
 
 
     const updateAwarenessHandler = (update: {
     const updateAwarenessHandler = (update: {
@@ -86,21 +89,10 @@ export const useCollaborativeEditorMode = (
       updated: number[];
       updated: number[];
       removed: number[];
       removed: number[];
     }) => {
     }) => {
-      // remove the states of disconnected clients
       for (const clientId of update.removed) {
       for (const clientId of update.removed) {
         awareness.getStates().delete(clientId);
         awareness.getStates().delete(clientId);
       }
       }
-
-      // update editor list
-      if (onEditorsUpdated != null) {
-        const clientList: EditingClient[] = Array.from(
-          awareness.getStates().values(),
-          (value) => value.editors,
-        );
-        if (Array.isArray(clientList)) {
-          onEditorsUpdated(clientList);
-        }
-      }
+      emitEditorList();
     };
     };
 
 
     awareness.on('update', updateAwarenessHandler);
     awareness.on('update', updateAwarenessHandler);