guard-socket.spec.ts 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. import http from 'node:http';
  2. import WebSocket, { WebSocketServer } from 'ws';
  3. import { docs, setPersistence, setupWSConnection } from 'y-websocket/bin/utils';
  4. import { guardSocket } from './guard-socket';
  5. /**
  6. * Creates a test server where:
  7. * 1. The Yjs upgrade handler guards the socket and awaits before completing
  8. * 2. A hostile handler (simulating Next.js) calls socket.end() for /yjs/ paths
  9. */
  10. const createServerWithHostileHandler = (): {
  11. server: http.Server;
  12. wss: WebSocketServer;
  13. } => {
  14. const server = http.createServer();
  15. const wss = new WebSocketServer({ noServer: true });
  16. // Yjs handler (registered first — same order as production)
  17. server.on('upgrade', async (request, socket, head) => {
  18. const url = request.url ?? '';
  19. if (!url.startsWith('/yjs/')) return;
  20. const pageId = url.slice('/yjs/'.length).split('?')[0];
  21. const guard = guardSocket(socket);
  22. try {
  23. // Simulate async auth delay
  24. await new Promise((resolve) => setTimeout(resolve, 10));
  25. guard.restore();
  26. wss.handleUpgrade(request, socket, head, (ws) => {
  27. wss.emit('connection', ws, request);
  28. setupWSConnection(ws, request, { docName: pageId });
  29. });
  30. } catch {
  31. guard.restore();
  32. socket.destroy();
  33. }
  34. });
  35. // Hostile handler (registered second — simulates Next.js upgradeHandler)
  36. server.on('upgrade', (_request, socket) => {
  37. socket.end();
  38. });
  39. return { server, wss };
  40. };
  41. const connectClient = (port: number, pageId: string): Promise<WebSocket> => {
  42. return new Promise((resolve, reject) => {
  43. const ws = new WebSocket(`ws://127.0.0.1:${port}/yjs/${pageId}`);
  44. ws.binaryType = 'arraybuffer';
  45. ws.on('open', () => resolve(ws));
  46. ws.on('error', reject);
  47. });
  48. };
  49. describe('guardSocket — protection against hostile upgrade handlers', () => {
  50. let server: http.Server;
  51. let wss: WebSocketServer;
  52. let port: number;
  53. beforeAll(async () => {
  54. setPersistence(null);
  55. const testServer = createServerWithHostileHandler();
  56. server = testServer.server;
  57. wss = testServer.wss;
  58. await new Promise<void>((resolve) => {
  59. server.listen(0, '127.0.0.1', () => {
  60. const addr = server.address();
  61. if (addr && typeof addr === 'object') {
  62. port = addr.port;
  63. }
  64. resolve();
  65. });
  66. });
  67. });
  68. afterAll(async () => {
  69. for (const [name, doc] of docs) {
  70. doc.destroy();
  71. docs.delete(name);
  72. }
  73. await new Promise<void>((resolve) => {
  74. wss.close(() => {
  75. server.close(() => resolve());
  76. });
  77. });
  78. });
  79. afterEach(() => {
  80. for (const [name, doc] of docs) {
  81. doc.destroy();
  82. docs.delete(name);
  83. }
  84. });
  85. it('should establish WebSocket connection even when a hostile handler calls socket.end()', async () => {
  86. const pageId = 'guard-test-001';
  87. const ws = await connectClient(port, pageId);
  88. await new Promise((resolve) => setTimeout(resolve, 50));
  89. const serverDoc = docs.get(pageId);
  90. expect(serverDoc).toBeDefined();
  91. assert(serverDoc !== undefined);
  92. expect(serverDoc.conns.size).toBe(1);
  93. ws.close();
  94. });
  95. it('should handle multiple concurrent connections with hostile handler', async () => {
  96. const pageId = 'guard-test-002';
  97. const connections = await Promise.all([
  98. connectClient(port, pageId),
  99. connectClient(port, pageId),
  100. ]);
  101. await new Promise((resolve) => setTimeout(resolve, 50));
  102. const serverDoc = docs.get(pageId);
  103. expect(serverDoc).toBeDefined();
  104. assert(serverDoc !== undefined);
  105. expect(serverDoc.conns.size).toBe(2);
  106. for (const ws of connections) {
  107. ws.close();
  108. }
  109. });
  110. it('should properly restore socket methods after guard', async () => {
  111. const pageId = 'guard-test-003';
  112. const ws = await connectClient(port, pageId);
  113. await new Promise((resolve) => setTimeout(resolve, 50));
  114. // Connection succeeds, meaning socket.end/destroy were properly
  115. // guarded during async auth and restored before wss.handleUpgrade
  116. expect(ws.readyState).toBe(WebSocket.OPEN);
  117. ws.close();
  118. await new Promise((resolve) => setTimeout(resolve, 50));
  119. // After close, the server doc should have removed the connection
  120. const serverDoc = docs.get(pageId);
  121. if (serverDoc) {
  122. expect(serverDoc.conns.size).toBe(0);
  123. }
  124. });
  125. });