diff --git a/src/migration.ts b/src/migration.ts index 4012408..cb97ba5 100644 --- a/src/migration.ts +++ b/src/migration.ts @@ -1,4 +1,5 @@ import { promises as fsp } from "fs"; +import { relative } from "path"; import { VALID_FILE_REGEX } from "./current"; import { calculateHash } from "./hash"; @@ -118,73 +119,86 @@ export function compilePlaceholders( )(content); } +async function realpathOrNull(path: string): Promise { + try { + return await fsp.realpath(path); + } catch (e) { + return null; + } +} + export async function compileIncludes( parsedSettings: ParsedSettings, content: string, - processedFiles: Array = [], + processedFiles: ReadonlySet = new Set(), ): Promise { - const regex = /--!include (.*.sql)/g; - let compiledContent = content; - let match = regex.exec(content); - const includePath = `${parsedSettings.migrationsFolder}/fixtures/`; - let realPath; + const regex = /^--!include\s+(.*\.sql)\s*$/gm; - //if the fixtures folder isn't defined, catch the error and return the original content. - try { - realPath = await fsp.realpath(includePath); - } catch (e) { - if (!realPath) { - parsedSettings.logger.warn(`Warning: ${includePath} is not defined.`); - return content; - } - } + // Don't need to validate this unless an include happens. MUST end in a `/` + const fixturesPath = `${parsedSettings.migrationsFolder}/fixtures/`; - if (match) { - while (match != null) { - //make sure the include path starts with the real path of the fixtures folder. - let includeRegex; - let includeRealPath; + // Find all includes in this `content` + const matches = content.matchAll(regex); - try { - includeRegex = new RegExp(`^${realPath}`); - includeRealPath = await fsp.realpath(`${includePath}${match[1]}`); - } catch (e) { - throw new Error( - `include path not in ${parsedSettings.migrationsFolder}/fixtures/`, - ); - } + // Go through these matches and resolve their full paths, checking they are allowed + const sqlPathByRawSqlPath: Record = Object.create(null); + for (const match of matches) { + const [, rawSqlPath] = match; + const sqlPath = await realpathOrNull(`${fixturesPath}${rawSqlPath}`); - if (includeRegex.exec(includeRealPath) === null) { - throw new Error( - `include path not in ${parsedSettings.migrationsFolder}/fixtures/`, - ); - } + if (!sqlPath) { + throw new Error( + `Include of '${rawSqlPath}' failed because '${fixturesPath}${rawSqlPath}' doesn't seem to exist?`, + ); + } - //If we've already processed this file, skip it (prevents infinite chains) - if (!processedFiles.includes(includeRealPath)) { - processedFiles.push(includeRealPath); - const fileContents = await fsp.readFile(includeRealPath, "utf8"); - compiledContent = compiledContent.replace( - match[0], - fileContents.replace(/\$/g, "$$$$"), - ); - match = regex.exec(content); - } else { - //remove recursive include and continue - compiledContent = compiledContent.replace(match[0], ""); - match = regex.exec(content); - } + if (processedFiles.has(sqlPath)) { + throw new Error( + `Circular include detected - '${sqlPath}' is included again! Trace:\n ${[...processedFiles].reverse().join("\n ")}`, + ); } - //recursively call compileIncludes to catch includes in the included files. - return await compileIncludes( - parsedSettings, - compiledContent, - processedFiles, - ); - } else { - return compiledContent; + const relativePath = relative(fixturesPath, sqlPath); + if (relativePath.startsWith("..")) { + throw new Error( + `Forbidden: cannot include path '${sqlPath}' because it's not inside '${fixturesPath}'`, + ); + } + + // Looks good to me + sqlPathByRawSqlPath[rawSqlPath] = sqlPath; } + + // For the unique set of paths, load the file and then recursively do its own includes + const distinctSqlPaths = [...new Set(Object.values(sqlPathByRawSqlPath))]; + const contentsForDistinctSqlPaths = await Promise.all( + distinctSqlPaths.map(async (sqlPath) => { + const fileContents = await fsp.readFile(sqlPath, "utf8"); + const processed = await compileIncludes( + parsedSettings, + fileContents, + new Set([...processedFiles, sqlPath]), + ); + return processed; + }), + ); + + // Turn the results into a map for ease of lookup + const contentBySqlPath: Record = Object.create(null); + for (let i = 0, l = distinctSqlPaths.length; i < l; i++) { + const sqlPath = distinctSqlPaths[i]; + const content = contentsForDistinctSqlPaths[i]; + contentBySqlPath[sqlPath] = content; + } + + // Simple string replacement for each path matched + const compiledContent = content.replace(regex, (_match, rawSqlPath) => { + const sqlPath = sqlPathByRawSqlPath[rawSqlPath]; + const content = contentBySqlPath[sqlPath]; + return content; + }); + + return compiledContent; } const TABLE_CHECKS = {