diff --git a/src/stream-server.test.ts b/src/stream-server.test.ts new file mode 100644 index 00000000..096a4de2 --- /dev/null +++ b/src/stream-server.test.ts @@ -0,0 +1,64 @@ +import { describe, it, expect } from 'vitest'; +import { isAllowedOrigin } from './stream-server.js'; + +describe('isAllowedOrigin', () => { + describe('allowed origins', () => { + it('should allow connections with no origin (CLI tools)', () => { + expect(isAllowedOrigin(undefined)).toBe(true); + }); + + it('should allow empty string origin', () => { + expect(isAllowedOrigin('')).toBe(true); + }); + + it('should allow file:// origins', () => { + expect(isAllowedOrigin('file:///path/to/viewer.html')).toBe(true); + expect(isAllowedOrigin('file:///C:/Users/user/viewer.html')).toBe(true); + }); + + it('should allow http://localhost origins', () => { + expect(isAllowedOrigin('http://localhost')).toBe(true); + expect(isAllowedOrigin('http://localhost:3000')).toBe(true); + expect(isAllowedOrigin('http://localhost:9223')).toBe(true); + expect(isAllowedOrigin('http://localhost:8080')).toBe(true); + }); + + it('should allow https://localhost origins', () => { + expect(isAllowedOrigin('https://localhost')).toBe(true); + expect(isAllowedOrigin('https://localhost:3000')).toBe(true); + }); + + it('should allow http://127.0.0.1 origins', () => { + expect(isAllowedOrigin('http://127.0.0.1')).toBe(true); + expect(isAllowedOrigin('http://127.0.0.1:3000')).toBe(true); + expect(isAllowedOrigin('http://127.0.0.1:9223')).toBe(true); + }); + + it('should allow IPv6 loopback origins', () => { + expect(isAllowedOrigin('http://[::1]')).toBe(true); + expect(isAllowedOrigin('http://[::1]:3000')).toBe(true); + }); + }); + + describe('rejected origins', () => { + it('should reject remote origins', () => { + expect(isAllowedOrigin('https://evil.com')).toBe(false); + expect(isAllowedOrigin('http://attacker.local:8080')).toBe(false); + expect(isAllowedOrigin('https://example.com')).toBe(false); + }); + + it('should reject origins with localhost in path but not hostname', () => { + expect(isAllowedOrigin('https://evil.com/localhost')).toBe(false); + }); + + it('should reject origins that look like localhost but are not', () => { + expect(isAllowedOrigin('http://localhost.evil.com')).toBe(false); + expect(isAllowedOrigin('http://not-localhost:3000')).toBe(false); + }); + + it('should reject invalid origin URLs', () => { + expect(isAllowedOrigin('not-a-url')).toBe(false); + expect(isAllowedOrigin('://missing-scheme')).toBe(false); + }); + }); +}); diff --git a/src/stream-server.ts b/src/stream-server.ts index 83136d6b..6a0505c0 100644 --- a/src/stream-server.ts +++ b/src/stream-server.ts @@ -2,6 +2,33 @@ import { WebSocketServer, WebSocket } from 'ws'; import type { BrowserManager, ScreencastFrame } from './browser.js'; import { setScreencastFrameCallback } from './actions.js'; +/** + * Check whether a WebSocket connection origin should be allowed. + * Allows: no origin (CLI tools), file:// origins, and localhost/loopback origins. + * Rejects: all other origins (prevents malicious web pages from connecting). + */ +export function isAllowedOrigin(origin: string | undefined): boolean { + // Allow connections with no origin (non-browser clients like CLI tools) + if (!origin) { + return true; + } + // Allow file:// origins (local HTML files) + if (origin.startsWith('file://')) { + return true; + } + // Allow localhost/loopback origins (browser-based stream viewers) + try { + const url = new URL(origin); + const host = url.hostname; + if (host === 'localhost' || host === '127.0.0.1' || host === '::1' || host === '[::1]') { + return true; + } + } catch { + // Invalid origin URL - reject + } + return false; +} + // Message types for WebSocket communication export interface FrameMessage { type: 'frame'; @@ -89,21 +116,19 @@ export class StreamServer { try { this.wss = new WebSocketServer({ port: this.port, - // Security: Reject cross-origin WebSocket connections from browsers. + // Security: Reject cross-origin WebSocket connections from untrusted origins. // This prevents malicious web pages from connecting and injecting input events. + // Localhost origins are allowed so browser-based stream viewers can connect. verifyClient: (info: { origin: string; secure: boolean; req: import('http').IncomingMessage; }) => { - const origin = info.origin; - // Allow connections with no origin (non-browser clients like CLI tools) - // Reject connections from web pages (which always have an origin) - if (origin && !origin.startsWith('file://')) { - console.log(`[StreamServer] Rejected connection from origin: ${origin}`); - return false; + if (isAllowedOrigin(info.origin)) { + return true; } - return true; + console.log(`[StreamServer] Rejected connection from origin: ${info.origin}`); + return false; }, });