Skip to content

Commit

Permalink
Merge pull request #5 from Lordfirespeed/refactor-origin-handler
Browse files Browse the repository at this point in the history
Refactor origin handler
  • Loading branch information
talentlessguy authored Jul 12, 2024
2 parents 57cf95f + 78645fc commit 11b94a7
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 27 deletions.
2 changes: 1 addition & 1 deletion .husky/pre-commit
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#!/bin/sh
#!/bin/sh -l
. "$(dirname "$0")/_/husky.sh"

pnpm format && pnpm lint && pnpm build && pnpm test
84 changes: 66 additions & 18 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { IncomingMessage as Request, ServerResponse as Response } from 'http'
import { vary } from 'es-vary'

export interface AccessControlOptions {
origin?: string | boolean | ((req: Request, res: Response) => string) | Array<string> | RegExp
origin?: string | boolean | ((req: Request, res: Response) => string) | Iterable<string> | RegExp
methods?: string[]
allowedHeaders?: string[]
exposedHeaders?: string[]
Expand All @@ -12,6 +12,68 @@ export interface AccessControlOptions {
preflightContinue?: boolean
}

const isIterable = (obj: unknown): obj is Iterable<unknown> => typeof obj[Symbol.iterator] === 'function'

const failOriginParam = () => {
throw new TypeError('No other objects allowed. Allowed types is array of strings or RegExp')
}

const getOriginHeaderHandler = (origin: unknown): ((req: Request, res: Response) => void) => {
if (typeof origin === 'boolean') {
return origin
? (_, res) => {
res.setHeader('Access-Control-Allow-Origin', '*')
}
: () => undefined
}

if (typeof origin === 'string') {
return (_, res) => {
res.setHeader('Access-Control-Allow-Origin', origin)
}
}

if (typeof origin === 'function') {
return (req, res) => {
vary(res, 'Origin')
res.setHeader('Access-Control-Allow-Origin', origin(req, res))
}
}

if (typeof origin !== 'object') failOriginParam()

if (isIterable(origin)) {
const originArray = Array.from(origin)
if (originArray.some((element) => typeof element !== 'string')) failOriginParam()

const originSet = new Set(origin)

if (originSet.has('*')) {
return (_, res) => {
res.setHeader('Access-Control-Allow-Origin', '*')
}
}

return (req, res) => {
vary(res, 'Origin')
if (req.headers.origin === undefined) return
if (!originSet.has(req.headers.origin)) return
res.setHeader('Access-Control-Allow-Origin', req.headers.origin)
}
}

if (origin instanceof RegExp) {
return (req, res) => {
vary(res, 'Origin')
if (req.headers.origin === undefined) return
if (!origin.test(req.headers.origin)) return
res.setHeader('Access-Control-Allow-Origin', req.headers.origin)
}
}

failOriginParam()
}

/**
* CORS Middleware
*/
Expand All @@ -26,24 +88,10 @@ export const cors = (opts: AccessControlOptions = {}) => {
optionsSuccessStatus = 204,
preflightContinue = false
} = opts
const originHeaderHandler = getOriginHeaderHandler(origin)

return (req: Request, res: Response, next?: () => void) => {
// Checking the type of the origin property
if (typeof origin === 'boolean' && origin === true) {
res.setHeader('Access-Control-Allow-Origin', '*')
} else if (typeof origin === 'string') {
res.setHeader('Access-Control-Allow-Origin', origin)
} else if (typeof origin === 'function') {
res.setHeader('Access-Control-Allow-Origin', origin(req, res))
} else if (typeof origin === 'object') {
if (Array.isArray(origin) && (origin.indexOf(req.headers.origin) !== -1 || origin.indexOf('*') !== -1)) {
res.setHeader('Access-Control-Allow-Origin', req.headers.origin)
} else if (origin instanceof RegExp && origin.test(req.headers.origin)) {
res.setHeader('Access-Control-Allow-Origin', req.headers.origin)
} else {
throw new TypeError('No other objects allowed. Allowed types is array of strings or RegExp')
}
}
if ((typeof origin === 'string' && origin !== '*') || typeof origin === 'function') vary(res, 'Origin')
originHeaderHandler(req, res)

// Setting the Access-Control-Allow-Methods header from the methods array
res.setHeader('Access-Control-Allow-Methods', methods.join(', ').toUpperCase())
Expand Down
40 changes: 32 additions & 8 deletions tests/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,44 @@ describe('CORS headers tests', (it) => {
'http://example.com'
)
})
it('should set origin if it is an array', async () => {
const app = createServer(cors({ origin: ['http://example.com', 'example.com', 'https://example.com'] }))
describe('when origin is an array of strings', (it) => {
it('should set origin when origin header is included in request and whitelisted', async () => {
const app = createServer(cors({ origin: ['http://example.com', 'example.com', 'https://example.com'] }))

const fetch = makeFetch(app)
const fetch = makeFetch(app)

await fetch('/', { headers: { Origin: 'http://example.com' } }).expect(
'Access-Control-Allow-Origin',
'http://example.com'
)
await fetch('/', { headers: { Origin: 'http://example.com' } }).expect(
'Access-Control-Allow-Origin',
'http://example.com'
)
})
it('should not set origin when origin header is included in request but not whitelisted', async () => {
const app = createServer(cors({ origin: ['http://example.com', 'example.com', 'https://example.com'] }))

const fetch = makeFetch(app)

await fetch('/', { headers: { Origin: 'http://not-example.com' } }).expect('Access-Control-Allow-Origin', null)
})
it('should not set origin when origin header is excluded from request', async () => {
const app = createServer(cors({ origin: ['http://example.com', 'example.com', 'https://example.com'] }))

const fetch = makeFetch(app)

await fetch('/').expect('Access-Control-Allow-Origin', null)
})
})
it('should send an error if origin is an iterable containing a non-string', async () => {
try {
// @ts-ignore
const middleware = cors({ origin: [{}, 3, 'abc'] })
} catch (e) {
assert.strictEqual(e.message, 'No other objects allowed. Allowed types is array of strings or RegExp')
}
})
it('should send an error if it is other object types', () => {
try {
// @ts-ignore
const app = createServer(cors({ origin: { site: 'http://example.com' } }))
const middleware = cors({ origin: { site: 'http://example.com' } })
} catch (e) {
assert.strictEqual(e.message, 'No other objects allowed. Allowed types is array of strings or RegExp')
}
Expand Down

0 comments on commit 11b94a7

Please sign in to comment.