diff --git a/lib/safe-http-client.js b/lib/safe-http-client.js index 58093cd..cfe1170 100644 --- a/lib/safe-http-client.js +++ b/lib/safe-http-client.js @@ -1,8 +1,6 @@ import dns from 'node:dns/promises' - import {fetch} from 'undici' import ipaddr from 'ipaddr.js' - import {defaultMimeTypes} from './constants.js' export class HttpError extends Error { @@ -28,122 +26,119 @@ export class HttpError extends Error { } } -export class SafeHttpClient { - /** @param {number} [maxSize] */ - constructor(maxSize) { - this.maxSize = maxSize +/** + * Check that the URL is valid, the prototol is allowed, and that the host is + * a safe unicast address. + * + * @param {string} url + * URL to check. + * @returns {Promise} + * URL object. + */ +export async function checkUrl(url) { + // Throws if the URL is invalid + const validUrl = new URL(url) + const {protocol, hostname} = validUrl + + // Don't allow aother protocols like file:// URLs + if (!['http:', 'https:'].includes(protocol)) { + throw new Error('Bad protocol') } - /** - * Check if the URL is a valid URL or IP, that the prototol is valid, - * and the host is a safe unicast address. - * @param {string} url - */ - static async checkUrl(url) { - // Throws if the URL is invalid - const validUrl = new URL(url) - const {protocol, hostname} = validUrl - - // Don't allow aother protocols like file:// URLs - if (!['http:', 'https:'].includes(protocol)) { - throw new Error('Bad protocol') - } - - try { - var {address} = await dns.lookup(hostname) - } catch (err) { - throw new Error('Bad url host') - } - - /** - * Server Side Request Forgery (SSRF) Protection. - * - * SSRF is an attack where an attacker can trick a server into making unexpected network connections. - * This can lead to unauthorized access to internal resources, information disclosure, - * denial-of-service attacks, or even remote code execution. - * - * One common SSRF vector is tricking the server into making requests to internal IP addresses - * or to other services within the network that the server shouldn't be accessing. This can - * expose sensitive internal data or systems. - * - * Unicast addresses are typically used for communication between hosts on the public internet. - * By only allowing addresses in the 'unicast' range, we can prevent SSRF attacks targeting - * non-public IP ranges, such as private, multicast, and reserved IPs. - */ - if (ipaddr.process(address).range() !== 'unicast') { - throw new Error('Bad url host') - } - - return validUrl + try { + var {address} = await dns.lookup(hostname) + } catch (err) { + throw new Error('Bad url host') } /** - * Fetch a URL. + * Server Side Request Forgery (SSRF) Protection. + * + * SSRF is an attack where an attacker can trick a server into making unexpected network connections. + * This can lead to unauthorized access to internal resources, information disclosure, + * denial-of-service attacks, or even remote code execution. + * + * One common SSRF vector is tricking the server into making requests to internal IP addresses + * or to other services within the network that the server shouldn't be accessing. This can + * expose sensitive internal data or systems. * - * @param {URL | string} url - * URL. - * @param {import('undici').RequestInit} options - * Configuration, passed through to `fetch`. - * @returns {Promise<{buffer?: Buffer, headers: import('undici').Headers}>} - * Buffer of response (except when `HEAD`) and headers. + * Unicast addresses are typically used for communication between hosts on the public internet. + * By only allowing addresses in the 'unicast' range, we can prevent SSRF attacks targeting + * non-public IP ranges, such as private, multicast, and reserved IPs. */ - async safeFetch(url, options) { - let response = await fetch(url, options) - - // If there's a redirect, check the redirected URL for SSRF and then follow it if it's valid. - if ([301, 302, 303, 307, 308].includes(response.status)) { - const redirectedUrl = response.headers.get('location') + if (ipaddr.process(address).range() !== 'unicast') { + throw new Error('Bad url host') + } +} - if (!redirectedUrl) { - throw new HttpError(400, 'Missing `Location` header') - } +/** + * Fetch a URL. + * + * @param {URL | string} url + * URL. + * @param {import('undici').RequestInit} options + * Configuration, passed through to `fetch`. + * @param {number} [maxSize] + * The max size in bytes to download. + * @returns {Promise<{buffer?: Buffer, headers: import('undici').Headers}>} + * Buffer of response (except when `HEAD`) and headers. + */ +export async function safeFetch(url, options, maxSize) { + let response = await fetch(url, options) + + // If there's a redirect, check the redirected URL for SSRF and then follow it if it's valid. + if ([301, 302, 303, 307, 308].includes(response.status)) { + const redirectedUrl = response.headers.get('location') + + if (!redirectedUrl) { + throw new HttpError(400, 'Missing `Location` header') + } - await SafeHttpClient.checkUrl(redirectedUrl) + await checkUrl(redirectedUrl) - response = await fetch(redirectedUrl, { - ...options, - // Do not allow another redirect - redirect: 'error' - }) - } + response = await fetch(redirectedUrl, { + ...options, + // Do not allow another redirect + redirect: 'error' + }) + } - const contentType = response.headers.get('content-type') - if (!contentType) { - throw new HttpError(400, 'Empty content-type header') - } + const contentType = response.headers.get('content-type') + if (!contentType) { + throw new HttpError(400, 'Empty content-type header') + } - if (!defaultMimeTypes.includes(contentType)) { - throw new HttpError(400, 'Unsupported content-type returned') - } + if (!defaultMimeTypes.includes(contentType)) { + throw new HttpError(400, 'Unsupported content-type returned') + } - if (options.method === 'HEAD') { - return {headers: response.headers} - } + if (options.method === 'HEAD') { + return {headers: response.headers} + } - if (!response.body) { - throw new HttpError(400, 'No response body') - } + if (!response.body) { + throw new HttpError(400, 'No response body') + } - /** @type {Array} */ - const chunks = [] - const reader = response.body.getReader() - let currentByteLength = 0 + /** @type {Array} */ + const chunks = [] + const reader = response.body.getReader() + let currentByteLength = 0 - while (true) { - const {done, value} = await reader.read() - if (done) { - break - } - chunks.push(value) + while (true) { + const {done, value} = await reader.read() + if (done) { + break + } + chunks.push(value) - if (this.maxSize) { - currentByteLength += value.length - if (currentByteLength > this.maxSize) { - throw new HttpError(413, 'Content-Length exceeded') - } + if (maxSize) { + currentByteLength += value.length + if (currentByteLength > maxSize) { + throw new HttpError(413, 'Content-Length exceeded') } } - - return {buffer: Buffer.concat(chunks), headers: response.headers} } + + return {buffer: Buffer.concat(chunks), headers: response.headers} } diff --git a/lib/server.js b/lib/server.js index fc88f00..32756c1 100644 --- a/lib/server.js +++ b/lib/server.js @@ -1,12 +1,9 @@ -import http from 'node:http' -import net from 'node:net' import {EventEmitter} from 'node:events' import crypto from 'node:crypto' -import url from 'node:url' - +import http from 'node:http' +import net from 'node:net' import {Headers} from 'undici' - -import {SafeHttpClient, HttpError} from './safe-http-client.js' +import {checkUrl, safeFetch, HttpError} from './safe-http-client.js' import { securityHeaders, defaultRequestHeaders, @@ -64,9 +61,12 @@ export class Server extends EventEmitter { } /** - * Start the server. Identical to `net.Server.listen()`. + * Start the server. + * * @param {Parameters['listen']>} args - * @public + * Arguments passedf to `net.Server.listen`. + * @returns {net.Server} + * Server. */ listen(...args) { return http.createServer(this.handle.bind(this)).listen(...args) @@ -77,14 +77,13 @@ export class Server extends EventEmitter { * Integrate with your own server by calling this method and routing all requests to it. * @param {http.IncomingMessage} req * @param {http.ServerResponse} res - * @public */ async handle(req, res) { if (req.method !== 'GET' && req.method !== 'HEAD') { return this.write(res, 405, 'Method not allowed') } - const paths = url.parse(req.url || '')?.path?.split('/') + const paths = req.url?.split('/') if (!paths || paths.length < 3) { return this.write(res, 404, 'Malformed request') @@ -98,7 +97,7 @@ export class Server extends EventEmitter { } try { - var validUrl = await SafeHttpClient.checkUrl(decodedUrl) + await checkUrl(decodedUrl) } catch (err) { const exception = /** @type {Error} */ (err) return this.write(res, 400, exception.message) @@ -115,17 +114,20 @@ export class Server extends EventEmitter { // TODO: respect forwarded headers (check if not private IP) const filterRequestHeaders = filterHeaders(defaultRequestHeaders) const filterResponseHeaders = filterHeaders(defaultResponseHeaders) - const client = new SafeHttpClient(this.options.maxSize) - const {buffer, headers: resHeaders} = await client.safeFetch(validUrl, { - // @ts-expect-error: `IncomingHttpHeaders` can be passed to `Headers` - headers: filterRequestHeaders(new Headers(req.headers)), - method: req.method, - // We can't blindly follow redirects as the initial checkUrl - // might have been safe, but the redirect location might not be. - // SafeHttpClient will check the redirect location before following it. - redirect: 'manual', - signal - }) + const {buffer, headers: resHeaders} = await safeFetch( + decodedUrl, + { + // @ts-expect-error: `IncomingHttpHeaders` can be passed to `Headers` + headers: filterRequestHeaders(new Headers(req.headers)), + method: req.method, + // We can't blindly follow redirects as the initial checkUrl + // might have been safe, but the redirect location might not be. + // safeFetch will check the redirect location before following it. + redirect: 'manual', + signal + }, + this.options.maxSize + ) const headers = { ...securityHeaders, @@ -148,8 +150,8 @@ export class Server extends EventEmitter { if (err.name === 'AbortError') { return } - const msg = err.message || 'Internal server error' - return this.write(res, 500, msg) + console.error(err) + return this.write(res, 500, 'Internal server error') } } } @@ -161,7 +163,7 @@ export class Server extends EventEmitter { */ verifyHmac(receivedDigest, hex) { // Hex-decode the URL - const decodedUrl = Buffer.from(hex, 'hex').toString() + const decodedUrl = String(Buffer.from(hex, 'hex')) // Verify the HMAC digest to ensure the URL hasn't been tampered with const hmac = crypto.createHmac('sha1', this.options.secret) diff --git a/readme.md b/readme.md index 6b04d6b..82cfa87 100644 --- a/readme.md +++ b/readme.md @@ -70,11 +70,11 @@ A standalone server. ```js import {Server} from 'camomile' -const HMACKey = process.env.CAMOMILE_HMAC_KEY +const secret = process.env.CAMOMILE_SECRET -if (!HMACKey) throw new Error('Missing `CAMOMILE_HMAC_KEY` in environment') +if (!secret) throw new Error('Missing `CAMOMILE_SECRET` in environment') -const server = new Server({HMACKey}) +const server = new Server({secret}) server.listen({host: '127.0.0.1', port: 1080}) ``` @@ -88,7 +88,7 @@ There is no default export. Creates a new camomile server with options. -#### `options.HMACKey` +#### `options.secret` The HMAC key to decrypt the URLs and used by [`rehype-github-image`][] (`string`, required). @@ -112,14 +112,14 @@ if the resource is larger than the maximum size. import express from 'express' import {Server} from 'camomile' -const HMACKey = process.env.CAMOMILE_HMAC_KEY -if (!HMACKey) throw new Error('Missing `CAMOMILE_HMAC_KEY` in environment') +const secret = process.env.CAMOMILE_SECRET +if (!secret) throw new Error('Missing `CAMOMILE_SECRET` in environment') const host = '127.0.0.1' const port = 1080 const app = express() const uploadApp = express() -const camomile = new Server({HMACKey}) +const camomile = new Server({secret}) uploadApp.all('*', camomile.handle.bind(camomile)) app.use('/uploads', uploadApp) @@ -136,13 +136,13 @@ import url from 'node:url' import {Server} from 'camomile' import Koa from 'koa' -const HMACKey = process.env.CAMOMILE_HMAC_KEY -if (!HMACKey) throw new Error('Missing `CAMOMILE_HMAC_KEY` in environment') +const secret = process.env.CAMOMILE_SECRET +if (!secret) throw new Error('Missing `CAMOMILE_SECRET` in environment') const port = 1080 const app = new Koa() const appCallback = app.callback() -const camomile = new Server({HMACKey}) +const camomile = new Server({secret}) const server = http.createServer((req, res) => { const urlPath = url.parse(req.url || '').pathname || '' @@ -164,11 +164,11 @@ server.listen(port) import createFastify from 'fastify' import {Server} from 'camomile' -const HMACKey = process.env.CAMOMILE_HMAC_KEY -if (!HMACKey) throw new Error('Missing `CAMOMILE_HMAC_KEY` in environment') +const secret = process.env.CAMOMILE_SECRET +if (!secret) throw new Error('Missing `CAMOMILE_SECRET` in environment') const fastify = createFastify({logger: true}) -const camomile = new Server({HMACKey}) +const camomile = new Server({secret}) /** * Add `content-type` so fastify forewards without a parser to the leave body untouched. @@ -223,7 +223,7 @@ export const config = { } const camomile = new Server({ - HMACKey: process.env.CAMOMILE_HMAC_KEY, + secret: process.env.CAMOMILE_SECRET, }) export default function handler(req: NextApiRequest, res: NextApiResponse) {