diff --git a/src/ast.ts b/src/ast.ts index 4b36fce..bc3511d 100644 --- a/src/ast.ts +++ b/src/ast.ts @@ -67,7 +67,7 @@ export class MemberExpression extends Node { } export class ArrayExpression extends Node { - constructor(public members: Node[], public length: Node | null) { + constructor(public type: Type, public members: Node[]) { super() } } diff --git a/src/generator.ts b/src/generator.ts index 7b4bb38..22b9b50 100644 --- a/src/generator.ts +++ b/src/generator.ts @@ -59,13 +59,22 @@ function format(node: AST | null): string { const qualifiers = node.qualifiers.length ? `${node.qualifiers.join(' ')} ` : '' + let type = format(node.type) + let body = '' if (node.declarations.length) { const members: string[] = [] for (const declaration of node.declarations) { let value = '' - if (declaration.value instanceof ArrayExpression && !declaration.value.members.length) { - value = `[${format(declaration.value.length)}]` + + if (declaration.value instanceof ArrayExpression) { + const t = declaration.value.type + const params = t.parameters ? t.parameters?.map(format).join(', ') : '' + value = `[${params}]` + + if (declaration.value.members.length) { + value += ` = ${type}[${params}](${declaration.value.members.map(format).join(', ')})` + } } else if (declaration.value) { value = ` = ${format(declaration.value)}` } @@ -75,7 +84,7 @@ function format(node: AST | null): string { body = members.join(', ') } - line = `${layout}${qualifiers}${format(node.type)} ${body};\n`.trimStart() + line = `${layout}${qualifiers}${type} ${body};\n`.trimStart() } else if (node instanceof FunctionDeclaration) { const qualifiers = node.qualifiers.length ? `${node.qualifiers.join(' ')} ` : '' const args = node.args.map((node) => format(node).replace(';\n', '')).join(', ') @@ -86,7 +95,8 @@ function format(node: AST | null): string { } else if (node instanceof MemberExpression) { line = `${format(node.object)}.${format(node.property)}` } else if (node instanceof ArrayExpression) { - // TODO + const params = node.type.parameters ? node.type.parameters?.map(format).join(', ') : '' + line = `${node.type.name}[${params}](${node.members.map(format).join(', ')})` } else if (node instanceof IfStatement) { const consequent = format(node.consequent).replace(EOL_REGEX, '') const alternate = node.alternate ? ` else ${format(node.alternate).replace(EOL_REGEX, '')}` : '' diff --git a/src/parser.ts b/src/parser.ts index b7c6efe..b26008f 100644 --- a/src/parser.ts +++ b/src/parser.ts @@ -191,7 +191,7 @@ function parseExpression(body: Token[]): AST | null { } return new CallExpression(callee, args) - } else if (second.value === '.' || second.value === '[') { + } else if (second.value === '.') { const object = new Identifier(first.value) const property = parseExpression([body[2]])! const left = new MemberExpression(object, property) @@ -204,6 +204,29 @@ function parseExpression(body: Token[]): AST | null { } return left + } else if (second.value === '[') { + let i = 2 + + const type = new Type(first.value, []) + + if (body[i].value !== ']') type.parameters!.push(parseExpression([body[i++]]) as any) + i++ // skip ] + + const scope = readUntil(')', body, i).slice(1, -1) + + const members: AST[] = [] + + let j = 0 + while (j < scope.length) { + const next = readUntil(',', scope, j) + j += next.length + + if (next[next.length - 1].value === ',') next.pop() + + members.push(parseExpression(next)!) + } + + return new ArrayExpression(type, members) } } @@ -230,7 +253,7 @@ function parseVariable( let prefix: AST | null = null if (body[j].value === '[') { j++ // skip [ - prefix = new ArrayExpression([], parseExpression([body[j++]])) + prefix = new ArrayExpression(new Type(type.name, [parseExpression([body[j++]]) as any]), []) j++ // skip ] } @@ -271,7 +294,7 @@ function parseFunction(qualifiers: string[]): FunctionDeclaration { let prefix: AST | null = null if (line[0]?.value === '[') { line.shift() // skip [ - prefix = new ArrayExpression([], parseExpression([line.shift()!])) + prefix = new ArrayExpression(new Type(type.name, [parseExpression([line.shift()!]) as any]), []) line.shift() // skip ] } @@ -459,7 +482,7 @@ function parseStatements(): AST[] { else if (token.value === 'do') statement = parseDoWhile() else if (token.value === 'switch') statement = parseSwitch() else if (token.value === 'precision') statement = parsePrecision() - else if (isDeclaration(token.value)) statement = parseIndeterminate() + else if (isDeclaration(token.value) && tokens[i].value !== '[') statement = parseIndeterminate() } if (statement) { @@ -534,7 +557,7 @@ method(true); foo.bar(); -uniform float test[3]; +const float array[3] = float[3](2.5, 7.0, 1.5); `.trim() const ast = parse(glsl) diff --git a/tests/parser.test.ts b/tests/parser.test.ts index de18286..d2ba311 100644 --- a/tests/parser.test.ts +++ b/tests/parser.test.ts @@ -179,8 +179,11 @@ describe('parser', () => { expect(statement.declarations[2].name).toBe('baz') expect(statement.declarations[2].value).toBeInstanceOf(ArrayExpression) expect((statement.declarations[2].value as ArrayExpression).members.length).toBe(0) - expect((statement.declarations[2].value as ArrayExpression).length).toBeInstanceOf(Literal) - expect(((statement.declarations[2].value as ArrayExpression).length as Literal).value).toBe('3') + expect((statement.declarations[2].value as ArrayExpression).type).toBeInstanceOf(Type) + expect((statement.declarations[2].value as ArrayExpression).type.name).toBe('float') + expect((statement.declarations[2].value as ArrayExpression).type.parameters!.length).toBe(1) + expect((statement.declarations[2].value as ArrayExpression).type.parameters![0]).toBeInstanceOf(Literal) + expect(((statement.declarations[2].value as ArrayExpression).type.parameters![0] as Literal).value).toBe('3') } }) @@ -351,4 +354,21 @@ describe('parser', () => { expect(statement.type.name).toBe('float') expect(statement.type.parameters).toBe(null) }) + + it('parses array expressions', () => { + const statement = parse('float[3](2.5, 7.0, 1.5);')[0] as ArrayExpression + expect(statement).toBeInstanceOf(ArrayExpression) + expect(statement.type).toBeInstanceOf(Type) + expect(statement.type.name).toBe('float') + expect(statement.type.parameters!.length).toBe(1) + expect(statement.type.parameters![0]).toBeInstanceOf(Literal) + expect((statement.type.parameters![0] as Literal).value).toBe('3') + expect(statement.members.length).toBe(3) + expect(statement.members[0]).toBeInstanceOf(Literal) + expect((statement.members[0] as Literal).value).toBe('2.5') + expect(statement.members[1]).toBeInstanceOf(Literal) + expect((statement.members[1] as Literal).value).toBe('7.0') + expect(statement.members[2]).toBeInstanceOf(Literal) + expect((statement.members[2] as Literal).value).toBe('1.5') + }) })