From ecfe3cfc46b41d5b7b89a9f541bac32bc99b15fb Mon Sep 17 00:00:00 2001 From: Ruben Verborgh Date: Tue, 1 Dec 2020 21:24:43 +0100 Subject: [PATCH] feat: Support the Forwarded header. --- src/ldp/UnsecureWebSocketsProtocol.ts | 10 +-- src/ldp/http/BasicTargetExtractor.ts | 29 +++++--- src/util/HeaderUtil.ts | 34 +++++++++ test/integration/WebSocketsProtocol.test.ts | 23 ++++--- .../ldp/UnsecureWebSocketsProtocol.test.ts | 69 +++++++++++-------- .../ldp/http/BasicTargetExtractor.test.ts | 18 +++++ test/unit/util/HeaderUtil.test.ts | 32 +++++++++ 7 files changed, 162 insertions(+), 53 deletions(-) diff --git a/src/ldp/UnsecureWebSocketsProtocol.ts b/src/ldp/UnsecureWebSocketsProtocol.ts index 7f0871ef8..65df58142 100644 --- a/src/ldp/UnsecureWebSocketsProtocol.ts +++ b/src/ldp/UnsecureWebSocketsProtocol.ts @@ -3,6 +3,7 @@ import type WebSocket from 'ws'; import { getLoggerFor } from '../logging/LogUtil'; import type { HttpRequest } from '../server/HttpRequest'; import { WebSocketHandler } from '../server/WebSocketHandler'; +import { parseForwarded } from '../util/HeaderUtil'; import type { ResourceIdentifier } from './representation/ResourceIdentifier'; const VERSION = 'solid/0.1.0-alpha'; @@ -26,13 +27,13 @@ class WebSocketListener extends EventEmitter { socket.addListener('message', (message: string): void => this.onMessage(message)); } - public start(upgradeRequest: HttpRequest): void { + public start({ headers, socket }: HttpRequest): void { // Greet the client this.sendMessage('protocol', VERSION); this.sendMessage('warning', 'Unstandardized protocol version, proceed with care'); // Verify the WebSocket protocol version - const protocolHeader = upgradeRequest.headers['sec-websocket-protocol']; + const protocolHeader = headers['sec-websocket-protocol']; if (!protocolHeader) { this.sendMessage('warning', `Missing Sec-WebSocket-Protocol header, expected value '${VERSION}'`); } else { @@ -44,8 +45,9 @@ class WebSocketListener extends EventEmitter { } // Store the HTTP host and protocol - this.host = upgradeRequest.headers.host ?? ''; - this.protocol = (upgradeRequest.socket as any).secure ? 'https:' : 'http:'; + const forwarded = parseForwarded(headers.forwarded); + this.host = forwarded.host ?? headers.host ?? 'localhost'; + this.protocol = forwarded.proto === 'https' || (socket as any).secure ? 'https:' : 'http:'; } private stop(): void { diff --git a/src/ldp/http/BasicTargetExtractor.ts b/src/ldp/http/BasicTargetExtractor.ts index 84ed0cdcc..ae9b55fba 100644 --- a/src/ldp/http/BasicTargetExtractor.ts +++ b/src/ldp/http/BasicTargetExtractor.ts @@ -1,6 +1,6 @@ import type { TLSSocket } from 'tls'; -import { getLoggerFor } from '../../logging/LogUtil'; import type { HttpRequest } from '../../server/HttpRequest'; +import { parseForwarded } from '../../util/HeaderUtil'; import { toCanonicalUriPath } from '../../util/PathUtil'; import type { ResourceIdentifier } from '../representation/ResourceIdentifier'; import { TargetExtractor } from './TargetExtractor'; @@ -11,28 +11,35 @@ import { TargetExtractor } from './TargetExtractor'; * TODO: input requires more extensive cleaning/parsing based on headers (see #22). */ export class BasicTargetExtractor extends TargetExtractor { - protected readonly logger = getLoggerFor(this); - - public async handle({ url, headers: { host }, connection }: HttpRequest): Promise { + public async handle({ url, connection, headers }: HttpRequest): Promise { if (!url) { - this.logger.error('The request has no URL'); throw new Error('Missing URL'); } + + // Extract host and protocol (possibly overridden by the Forwarded header) + let { host } = headers; + let protocol = (connection as TLSSocket)?.encrypted ? 'https' : 'http'; + if (headers.forwarded) { + const forwarded = parseForwarded(headers.forwarded); + if (forwarded.host) { + ({ host } = forwarded); + } + if (forwarded.proto) { + ({ proto: protocol } = forwarded); + } + } + + // Perform a sanity check on the host if (!host) { - this.logger.error('The request has no Host header'); throw new Error('Missing Host header'); } if (/[/\\*]/u.test(host)) { throw new Error(`The request has an invalid Host header: ${host}`); } - const isHttps = (connection as TLSSocket)?.encrypted; - this.logger.debug(`Request is using HTTPS: ${isHttps}`); - // URL object applies punycode encoding to domain - const base = `http${isHttps ? 's' : ''}://${host}`; + const base = `${protocol}://${host}`; const path = new URL(toCanonicalUriPath(url), base).href; - return { path }; } } diff --git a/src/util/HeaderUtil.ts b/src/util/HeaderUtil.ts index f97c75197..2aac2b947 100644 --- a/src/util/HeaderUtil.ts +++ b/src/util/HeaderUtil.ts @@ -378,3 +378,37 @@ export const addHeader = (response: HttpResponse, name: string, value: string | } response.setHeader(name, allValues.length === 1 ? allValues[0] : allValues); }; + +/** + * The Forwarded header from RFC7239 + */ +export interface Forwarded { + /** The user-agent facing interface of the proxy */ + by?: string; + /** The node making the request to the proxy */ + for?: string; + /** The host request header field as received by the proxy */ + host?: string; + /** The protocol used to make the request */ + proto?: string; +} + +/** + * Parses a Forwarded header value. + * + * @param value - The Forwarded header value. + * + * @returns The parsed Forwarded header. + */ +export const parseForwarded = (value = ''): Forwarded => { + const forwarded: Record = {}; + if (value) { + for (const pair of value.replace(/\s*,.*$/u, '').split(';')) { + const components = /^(by|for|host|proto)=(.+)$/u.exec(pair); + if (components) { + forwarded[components[1]] = components[2]; + } + } + } + return forwarded; +}; diff --git a/test/integration/WebSocketsProtocol.test.ts b/test/integration/WebSocketsProtocol.test.ts index c679df58c..b12094d09 100644 --- a/test/integration/WebSocketsProtocol.test.ts +++ b/test/integration/WebSocketsProtocol.test.ts @@ -6,15 +6,16 @@ import { instantiateFromConfig } from '../configs/Util'; const port = 6001; const serverUrl = `http://localhost:${port}/`; +const headers = { forwarded: 'host=example.pod;proto=https' }; -describe('A server with the Solid WebSockets API', (): void => { +describe('A server with the Solid WebSockets API behind a proxy', (): void => { let server: Server; beforeAll(async(): Promise => { const factory = await instantiateFromConfig( 'urn:solid-server:default:ServerFactory', 'websockets.json', { 'urn:solid-server:default:variable:port': port, - 'urn:solid-server:default:variable:baseUrl': 'http://example.pod/', + 'urn:solid-server:default:variable:baseUrl': 'https://example.pod/', }, ) as HttpServerFactory; server = factory.startServer(port); @@ -27,17 +28,17 @@ describe('A server with the Solid WebSockets API', (): void => { }); it('returns a 200.', async(): Promise => { - const response = await fetch(serverUrl, { headers: { host: 'example.pod' }}); + const response = await fetch(serverUrl, { headers }); expect(response.status).toBe(200); }); it('sets the Updates-Via header.', async(): Promise => { - const response = await fetch(serverUrl, { headers: { host: 'example.pod' }}); - expect(response.headers.get('Updates-Via')).toBe('ws://example.pod/'); + const response = await fetch(serverUrl, { headers }); + expect(response.headers.get('Updates-Via')).toBe('wss://example.pod/'); }); it('exposes the Updates-Via header via CORS.', async(): Promise => { - const response = await fetch(serverUrl, { headers: { host: 'example.pod' }}); + const response = await fetch(serverUrl, { headers }); expect(response.headers.get('Access-Control-Expose-Headers')!.split(',')) .toContain('Updates-Via'); }); @@ -47,7 +48,7 @@ describe('A server with the Solid WebSockets API', (): void => { const messages = new Array(); beforeAll(async(): Promise => { - client = new WebSocket(`ws://localhost:${port}`, [ 'solid/0.1.0-alpha' ], { headers: { host: 'example.pod' }}); + client = new WebSocket(`ws://localhost:${port}`, [ 'solid/0.1.0-alpha' ], { headers }); client.on('message', (message: string): any => messages.push(message)); await new Promise((resolve): any => client.on('open', resolve)); }); @@ -69,24 +70,24 @@ describe('A server with the Solid WebSockets API', (): void => { describe('when the client subscribes to a resource', (): void => { beforeAll(async(): Promise => { - client.send(`sub http://example.pod/my-resource`); + client.send(`sub https://example.pod/my-resource`); await new Promise((resolve): any => client.once('message', resolve)); }); it('acknowledges the subscription.', async(): Promise => { - expect(messages).toEqual([ `ack http://example.pod/my-resource` ]); + expect(messages).toEqual([ `ack https://example.pod/my-resource` ]); }); it('notifies the client of resource updates.', async(): Promise => { await fetch(`${serverUrl}my-resource`, { method: 'PUT', headers: { - host: 'example.pod', + ...headers, 'content-type': 'application/json', }, body: '{}', }); - expect(messages).toEqual([ `pub http://example.pod/my-resource` ]); + expect(messages).toEqual([ `pub https://example.pod/my-resource` ]); }); }); }); diff --git a/test/unit/ldp/UnsecureWebSocketsProtocol.test.ts b/test/unit/ldp/UnsecureWebSocketsProtocol.test.ts index cf5befd42..0009b13e1 100644 --- a/test/unit/ldp/UnsecureWebSocketsProtocol.test.ts +++ b/test/unit/ldp/UnsecureWebSocketsProtocol.test.ts @@ -120,52 +120,67 @@ describe('An UnsecureWebSocketsProtocol', (): void => { }); it('unsubscribes when a socket closes.', async(): Promise => { - const newSocket = new DummySocket(); - await protocol.handle({ webSocket: newSocket, upgradeRequest: { headers: {}, socket: {}}} as any); - expect(newSocket.listenerCount('message')).toBe(1); - newSocket.emit('close'); - expect(newSocket.listenerCount('message')).toBe(0); - expect(newSocket.listenerCount('close')).toBe(0); - expect(newSocket.listenerCount('error')).toBe(0); + const webSocket = new DummySocket(); + await protocol.handle({ webSocket, upgradeRequest: { headers: {}, socket: {}}} as any); + expect(webSocket.listenerCount('message')).toBe(1); + webSocket.emit('close'); + expect(webSocket.listenerCount('message')).toBe(0); + expect(webSocket.listenerCount('close')).toBe(0); + expect(webSocket.listenerCount('error')).toBe(0); }); it('unsubscribes when a socket errors.', async(): Promise => { - const newSocket = new DummySocket(); - await protocol.handle({ webSocket: newSocket, upgradeRequest: { headers: {}, socket: {}}} as any); - expect(newSocket.listenerCount('message')).toBe(1); - newSocket.emit('error'); - expect(newSocket.listenerCount('message')).toBe(0); - expect(newSocket.listenerCount('close')).toBe(0); - expect(newSocket.listenerCount('error')).toBe(0); + const webSocket = new DummySocket(); + await protocol.handle({ webSocket, upgradeRequest: { headers: {}, socket: {}}} as any); + expect(webSocket.listenerCount('message')).toBe(1); + webSocket.emit('error'); + expect(webSocket.listenerCount('message')).toBe(0); + expect(webSocket.listenerCount('close')).toBe(0); + expect(webSocket.listenerCount('error')).toBe(0); }); it('emits a warning when no Sec-WebSocket-Protocol is supplied.', async(): Promise => { - const newSocket = new DummySocket(); + const webSocket = new DummySocket(); const upgradeRequest = { headers: {}, socket: {}, } as any as HttpRequest; - await protocol.handle({ webSocket: newSocket, upgradeRequest } as any); - expect(newSocket.messages).toHaveLength(3); - expect(newSocket.messages.pop()) + await protocol.handle({ webSocket, upgradeRequest } as any); + expect(webSocket.messages).toHaveLength(3); + expect(webSocket.messages.pop()) .toBe('warning Missing Sec-WebSocket-Protocol header, expected value \'solid/0.1.0-alpha\''); - expect(newSocket.close).toHaveBeenCalledTimes(0); + expect(webSocket.close).toHaveBeenCalledTimes(0); }); it('emits an error and closes the connection with the wrong Sec-WebSocket-Protocol.', async(): Promise => { - const newSocket = new DummySocket(); + const webSocket = new DummySocket(); const upgradeRequest = { headers: { 'sec-websocket-protocol': 'solid/1.0.0, other', }, socket: {}, } as any as HttpRequest; - await protocol.handle({ webSocket: newSocket, upgradeRequest } as any); - expect(newSocket.messages).toHaveLength(3); - expect(newSocket.messages.pop()).toBe('error Client does not support protocol solid/0.1.0-alpha'); - expect(newSocket.close).toHaveBeenCalledTimes(1); - expect(newSocket.listenerCount('message')).toBe(0); - expect(newSocket.listenerCount('close')).toBe(0); - expect(newSocket.listenerCount('error')).toBe(0); + await protocol.handle({ webSocket, upgradeRequest } as any); + expect(webSocket.messages).toHaveLength(3); + expect(webSocket.messages.pop()).toBe('error Client does not support protocol solid/0.1.0-alpha'); + expect(webSocket.close).toHaveBeenCalledTimes(1); + expect(webSocket.listenerCount('message')).toBe(0); + expect(webSocket.listenerCount('close')).toBe(0); + expect(webSocket.listenerCount('error')).toBe(0); + }); + + it('respects the Forwarded header.', async(): Promise => { + const webSocket = new DummySocket(); + const upgradeRequest = { + headers: { + forwarded: 'proto=https;host=other.example', + 'sec-websocket-protocol': 'solid/0.1.0-alpha', + }, + socket: {}, + } as any as HttpRequest; + await protocol.handle({ webSocket, upgradeRequest } as any); + webSocket.emit('message', 'sub https://other.example/protocol/foo'); + expect(webSocket.messages).toHaveLength(3); + expect(webSocket.messages.pop()).toBe('ack https://other.example/protocol/foo'); }); }); diff --git a/test/unit/ldp/http/BasicTargetExtractor.test.ts b/test/unit/ldp/http/BasicTargetExtractor.test.ts index ad61f7751..0520f4055 100644 --- a/test/unit/ldp/http/BasicTargetExtractor.test.ts +++ b/test/unit/ldp/http/BasicTargetExtractor.test.ts @@ -51,4 +51,22 @@ describe('A BasicTargetExtractor', (): void => { await expect(extractor.handle({ url: '/', headers: { host: '點看' }} as any)) .resolves.toEqual({ path: 'http://xn--c1yn36f/' }); }); + + it('ignores an irrelevant Forwarded header.', async(): Promise => { + const headers = { + host: 'test.com', + forwarded: 'by=203.0.113.60', + }; + await expect(extractor.handle({ url: '/foo/bar', headers } as any)) + .resolves.toEqual({ path: 'http://test.com/foo/bar' }); + }); + + it('takes the Forwarded header into account.', async(): Promise => { + const headers = { + host: 'test.com', + forwarded: 'proto=https;host=pod.example', + }; + await expect(extractor.handle({ url: '/foo/bar', headers } as any)) + .resolves.toEqual({ path: 'https://pod.example/foo/bar' }); + }); }); diff --git a/test/unit/util/HeaderUtil.test.ts b/test/unit/util/HeaderUtil.test.ts index 98a0052ae..865e9f643 100644 --- a/test/unit/util/HeaderUtil.test.ts +++ b/test/unit/util/HeaderUtil.test.ts @@ -5,6 +5,7 @@ import { parseAcceptCharset, parseAcceptEncoding, parseAcceptLanguage, + parseForwarded, } from '../../../src/util/HeaderUtil'; describe('HeaderUtil', (): void => { @@ -166,4 +167,35 @@ describe('HeaderUtil', (): void => { expect(response.getHeader('names')).toEqual([ 'oldValue1', 'oldValue2', 'value1', 'values2' ]); }); }); + + describe('parseForwarded', (): void => { + it('parses an undefined value.', (): void => { + expect(parseForwarded()).toEqual({}); + }); + + it('parses an empty string.', (): void => { + expect(parseForwarded('')).toEqual({}); + }); + + it('parses a Forwarded header value.', (): void => { + expect(parseForwarded('for=192.0.2.60;proto=http;by=203.0.113.43;host=example.org')).toEqual({ + by: '203.0.113.43', + for: '192.0.2.60', + host: 'example.org', + proto: 'http', + }); + }); + + it('skips empty fields.', (): void => { + expect(parseForwarded('for=192.0.2.60;proto=;by=;host=')).toEqual({ + for: '192.0.2.60', + }); + }); + + it('takes only the first value into account.', (): void => { + expect(parseForwarded('host=pod.example, for=192.0.2.43, host=other')).toEqual({ + host: 'pod.example', + }); + }); + }); });