diff --git a/Sources/Ast.swift b/Sources/Ast.swift index 7460284..fa26766 100644 --- a/Sources/Ast.swift +++ b/Sources/Ast.swift @@ -41,7 +41,7 @@ struct TupleLiteral: Literal { } struct ObjectLiteral: Literal { - var value: [(Expression, Expression)] + var value: [String: Expression] } struct Set: Statement { @@ -59,14 +59,14 @@ struct Identifier: Expression { var value: String } -protocol Loopvar {} -extension Identifier: Loopvar {} -extension TupleLiteral: Loopvar {} +typealias Loopvar = Expression struct For: Statement { var loopvar: Loopvar var iterable: Expression var body: [Statement] + var defaultBlock: [Statement] + var ifCondition: Expression? } struct MemberExpression: Expression { @@ -124,3 +124,23 @@ struct KeywordArgumentExpression: Expression { struct NullLiteral: Literal { var value: Any? = nil } + +struct SelectExpression: Expression { + var iterable: Expression + var test: Expression +} + +struct Macro: Statement { + var name: Identifier + var args: [Expression] + var body: [Statement] +} + +struct KeywordArgumentsValue: RuntimeValue { + var value: [String: any RuntimeValue] + var builtins: [String: any RuntimeValue] = [:] + + func bool() -> Bool { + !value.isEmpty + } +} diff --git a/Sources/Environment.swift b/Sources/Environment.swift index c845068..af9b811 100644 --- a/Sources/Environment.swift +++ b/Sources/Environment.swift @@ -12,42 +12,39 @@ class Environment { var variables: [String: any RuntimeValue] = [ "namespace": FunctionValue(value: { args, _ in - if args.count == 0 { + if args.isEmpty { return ObjectValue(value: [:]) } - if args.count != 1 || !(args[0] is ObjectValue) { + guard args.count == 1, let objectArg = args[0] as? ObjectValue else { throw JinjaError.runtime("`namespace` expects either zero arguments or a single object argument") } - return args[0] + return objectArg }) ] var tests: [String: (any RuntimeValue...) throws -> Bool] = [ - "boolean": { - args in - args[0] is BooleanValue + "boolean": { args in + return args[0] is BooleanValue }, - "callable": { - args in - args[0] is FunctionValue + "callable": { args in + return args[0] is FunctionValue }, - "odd": { - args in - if let arg = args.first as? NumericValue { - return arg.value as! Int % 2 != 0 + "odd": { args in + if let arg = args.first as? NumericValue, let intValue = arg.value as? Int { + return intValue % 2 != 0 } else { - throw JinjaError.runtime("Cannot apply test 'odd' to type: \(type(of:args.first))") + throw JinjaError.runtime("Cannot apply test 'odd' to type: \(type(of: args.first))") } }, "even": { args in - if let arg = args.first as? NumericValue { - return arg.value as! Int % 2 == 0 + if let arg = args.first as? NumericValue, let intValue = arg.value as? Int { + return intValue % 2 == 0 } else { - throw JinjaError.runtime("Cannot apply test 'even' to type: \(type(of:args.first))") + throw JinjaError.runtime("Cannot apply test 'even' to type: \(type(of: args.first))") } }, "false": { args in @@ -62,24 +59,28 @@ class Environment { } return false }, + "string": { args in + return args[0] is StringValue + }, "number": { args in - args[0] is NumericValue + return args[0] is NumericValue }, "integer": { args in if let arg = args[0] as? NumericValue { return arg.value is Int } - return false }, + "mapping": { args in + return args[0] is ObjectValue + }, "iterable": { args in - args[0] is ArrayValue || args[0] is StringValue + return args[0] is ArrayValue || args[0] is StringValue || args[0] is ObjectValue }, "lower": { args in if let arg = args[0] as? StringValue { return arg.value == arg.value.lowercased() } - return false }, "upper": { args in @@ -89,16 +90,47 @@ class Environment { return false }, "none": { args in - args[0] is NullValue + return args[0] is NullValue }, "defined": { args in - !(args[0] is UndefinedValue) + return !(args[0] is UndefinedValue) }, "undefined": { args in - args[0] is UndefinedValue + return args[0] is UndefinedValue }, - "equalto": { _ in - throw JinjaError.syntaxNotSupported("equalto") + "equalto": { args in + if args.count == 2 { + if let left = args[0] as? StringValue, let right = args[1] as? StringValue { + return left.value == right.value + } else if let left = args[0] as? NumericValue, let right = args[1] as? NumericValue, + let leftInt = left.value as? Int, let rightInt = right.value as? Int + { + return leftInt == rightInt + } else if let left = args[0] as? BooleanValue, let right = args[1] as? BooleanValue { + return left.value == right.value + } else { + return false + } + } else { + return false + } + }, + "eq": { args in + if args.count == 2 { + if let left = args[0] as? StringValue, let right = args[1] as? StringValue { + return left.value == right.value + } else if let left = args[0] as? NumericValue, let right = args[1] as? NumericValue, + let leftInt = left.value as? Int, let rightInt = right.value as? Int + { + return leftInt == rightInt + } else if let left = args[0] as? BooleanValue, let right = args[1] as? BooleanValue { + return left.value == right.value + } else { + return false + } + } else { + return false + } }, ] @@ -107,61 +139,74 @@ class Environment { } func isFunction(_ value: Any, functionType: T.Type) -> Bool { - value is T + return value is T } func convertToRuntimeValues(input: Any) throws -> any RuntimeValue { switch input { case let value as Bool: return BooleanValue(value: value) - case let values as [any Numeric]: - var items: [any RuntimeValue] = [] - for value in values { - try items.append(self.convertToRuntimeValues(input: value)) - } - return ArrayValue(value: items) case let value as any Numeric: return NumericValue(value: value) case let value as String: return StringValue(value: value) case let fn as (String) throws -> Void: return FunctionValue { args, _ in - var arg = "" - switch args[0].value { - case let value as String: - arg = value - case let value as Bool: - arg = String(value) - default: - throw JinjaError.runtime("Unknown arg type:\(type(of: args[0].value))") + guard let stringArg = args[0] as? StringValue else { + throw JinjaError.runtime("Argument must be a StringValue") } - - try fn(arg) + try fn(stringArg.value) return NullValue() } case let fn as (Bool) throws -> Void: return FunctionValue { args, _ in - try fn(args[0].value as! Bool) + guard let boolArg = args[0] as? BooleanValue else { + throw JinjaError.runtime("Argument must be a BooleanValue") + } + try fn(boolArg.value) return NullValue() } case let fn as (Int, Int?, Int) -> [Int]: return FunctionValue { args, _ in - let result = fn(args[0].value as! Int, args[1].value as? Int, args[2].value as! Int) + guard let arg0 = args[0] as? NumericValue, let int0 = arg0.value as? Int else { + throw JinjaError.runtime("First argument must be an Int") + } + let int1 = (args[1] as? NumericValue)?.value as? Int + guard let arg2 = args[2] as? NumericValue, let int2 = arg2.value as? Int else { + throw JinjaError.runtime("Third argument must be an Int") + } + let result = fn(int0, int1, int2) return try self.convertToRuntimeValues(input: result) } - case let values as [Any]: - var items: [any RuntimeValue] = [] - for value in values { - try items.append(self.convertToRuntimeValues(input: value)) + case let fn as ([Int]) -> [Int]: + return FunctionValue { args, _ in + let intArgs = args.compactMap { ($0 as? NumericValue)?.value as? Int } + guard intArgs.count == args.count else { + throw JinjaError.runtime("Arguments to range must be Ints") + } + let result = fn(intArgs) + return try self.convertToRuntimeValues(input: result) + } + case let fn as (Int, Int?, Int) -> [Int]: + return FunctionValue { args, _ in + guard let arg0 = args[0] as? NumericValue, let int0 = arg0.value as? Int else { + throw JinjaError.runtime("First argument must be an Int") + } + let int1 = (args.count > 1) ? (args[1] as? NumericValue)?.value as? Int : nil + guard let arg2 = args.last as? NumericValue, let int2 = arg2.value as? Int else { + throw JinjaError.runtime("Last argument must be an Int") + } + let result = fn(int0, int1, int2) + return try self.convertToRuntimeValues(input: result) } + case let values as [Any]: + let items = try values.map { try self.convertToRuntimeValues(input: $0) } return ArrayValue(value: items) - case let dictionary as [String: String]: + case let dictionary as [String: Any]: var object: [String: any RuntimeValue] = [:] - for (key, value) in dictionary { - object[key] = StringValue(value: value) + object[key] = try self.convertToRuntimeValues(input: value) } - return ObjectValue(value: object) case is NullValue: return NullValue() @@ -176,12 +221,11 @@ class Environment { } func declareVariable(name: String, value: any RuntimeValue) throws -> any RuntimeValue { - if self.variables.contains(where: { $0.0 == name }) { + if self.variables.keys.contains(name) { throw JinjaError.syntax("Variable already declared: \(name)") } self.variables[name] = value - return value } @@ -191,13 +235,13 @@ class Environment { return value } - func resolve(name: String) throws -> Self { - if self.variables.contains(where: { $0.0 == name }) { + func resolve(name: String) throws -> Environment { + if self.variables.keys.contains(name) { return self } - if let parent { - return try parent.resolve(name: name) as! Self + if let parent = self.parent { + return try parent.resolve(name: name) } throw JinjaError.runtime("Unknown variable: \(name)") @@ -205,11 +249,7 @@ class Environment { func lookupVariable(name: String) -> any RuntimeValue { do { - if let value = try self.resolve(name: name).variables[name] { - return value - } else { - return UndefinedValue() - } + return try self.resolve(name: name).variables[name] ?? UndefinedValue() } catch { return UndefinedValue() } diff --git a/Sources/Lexer.swift b/Sources/Lexer.swift index 1093960..851b0ca 100644 --- a/Sources/Lexer.swift +++ b/Sources/Lexer.swift @@ -50,6 +50,8 @@ enum TokenType: String { case and = "And" case or = "Or" case not = "Not" + case macro = "Macro" + case endMacro = "EndMacro" } struct Token: Equatable { @@ -70,6 +72,8 @@ let keywords: [String: TokenType] = [ "and": .and, "or": .or, "not": .not, + "macro": .macro, + "endmacro": .endMacro, // Literals "true": .booleanLiteral, "false": .booleanLiteral, diff --git a/Sources/Parser.swift b/Sources/Parser.swift index 648a025..2c6bdae 100644 --- a/Sources/Parser.swift +++ b/Sources/Parser.swift @@ -281,11 +281,19 @@ func parse(tokens: [Token]) throws -> Program { func parseTernaryExpression() throws -> Statement { let a = try parseLogicalOrExpression() if typeof(.if) { - current += 1 - let test = try parseLogicalOrExpression() - try expect(type: .else, error: "Expected else token") - let b = try parseLogicalOrExpression() - return If(test: test as! Expression, body: [a], alternate: [b]) + // Ternary expression + current += 1 // consume if + let predicate = try parseLogicalOrExpression() + + if typeof(.else) { + // Ternary expression with else + current += 1 // consume else + let b = try parseLogicalOrExpression() + return If(test: predicate as! Expression, body: [a], alternate: [b]) + } else { + // Select expression on iterable + return SelectExpression(iterable: a as! Expression, test: predicate as! Expression) + } } return a @@ -392,12 +400,20 @@ func parse(tokens: [Token]) throws -> Program { return ArrayLiteral(value: values) case .openCurlyBracket: current += 1 - var values: [(Expression, Expression)] = [] + var values: [String: Expression] = [:] while !typeof(.closeCurlyBracket) { let key = try parseExpression() try expect(type: .colon, error: "Expected colon between key and value in object literal") let value = try parseExpression() - values.append((key as! Expression, value as! Expression)) + + if let key = key as? StringLiteral { + values[key.value] = value as? Expression + } else if let key = key as? Identifier { + values[key.value] = value as? Expression + } else { + throw JinjaError.syntax("Expected string literal or identifier as key in object literal") + } + if typeof(.comma) { current += 1 } @@ -437,9 +453,9 @@ func parse(tokens: [Token]) throws -> Program { func parseForStatement() throws -> Statement { let loopVariable = try parseExpressionSequence(primary: true) - if !(loopVariable is Identifier || loopVariable is TupleLiteral) { + guard let loopvar = loopVariable as? Loopvar else { throw JinjaError.syntax( - "Expected identifier/tuple for the loop variable, got \(type(of:loopVariable)) instead" + "Expected identifier/tuple for the loop variable, got \(type(of: loopVariable)) instead" ) } @@ -447,22 +463,62 @@ func parse(tokens: [Token]) throws -> Program { let iterable = try parseExpression() + guard let iterableExpression = iterable as? Expression else { + throw JinjaError.syntax("Expected expression for iterable, got \(type(of: iterable))") + } + + // Handle optional if condition + var ifCondition: Expression? = nil + if typeof(.if) { + current += 1 // Consume 'if' token + ifCondition = try parseExpression() as? Expression + } + try expect(type: .closeStatement, error: "Expected closing statement token") var body: [Statement] = [] - while not(.openStatement, .endFor) { + var defaultBlock: [Statement] = [] + + while not(.openStatement, .endFor) && not(.openStatement, .else) { try body.append(parseAny()) } - if let loopVariable = loopVariable as? Loopvar { - return For(loopvar: loopVariable, iterable: iterable as! Expression, body: body) + if typeof(.openStatement, .else) { + current += 1 // Consume '{%' + try expect(type: .else, error: "Expected else token") + try expect(type: .closeStatement, error: "Expected closing statement token") + + while not(.openStatement, .endFor) { + try defaultBlock.append(parseAny()) + } } - throw JinjaError.syntax( - "Expected identifier/tuple for the loop variable, got \(type(of:loopVariable)) instead" + return For( + loopvar: loopvar, + iterable: iterableExpression, + body: body, + defaultBlock: defaultBlock, + ifCondition: ifCondition ) } + func parseMacroStatement() throws -> Macro { + let name = try parsePrimaryExpression() + if !(name is Identifier) { + throw JinjaError.syntax("Expected identifier following macro statement") + } + let args = try parseArgs() + try expect(type: .closeStatement, error: "Expected closing statement token") + + var body: [Statement] = [] + + while not(.openStatement, .endMacro) { + try body.append(parseAny()) + } + + return Macro(name: name as! Identifier, args: args as! [Expression], body: body) + } + func parseJinjaStatement() throws -> Statement { try expect(type: .openStatement, error: "Expected opening statement token") var result: Statement @@ -484,6 +540,12 @@ func parse(tokens: [Token]) throws -> Program { try expect(type: .openStatement, error: "Expected {% token") try expect(type: .endFor, error: "Expected endfor token") try expect(type: .closeStatement, error: "Expected %} token") + case .macro: + current += 1 + result = try parseMacroStatement() + try expect(type: .openStatement, error: "Expected {% token") + try expect(type: .endMacro, error: "Expected endmacro token") + try expect(type: .closeStatement, error: "Expected %} token") default: throw JinjaError.syntax("Unknown statement type: \(tokens[current].type)") } diff --git a/Sources/Runtime.swift b/Sources/Runtime.swift index 73a0d48..2070529 100644 --- a/Sources/Runtime.swift +++ b/Sources/Runtime.swift @@ -8,9 +8,9 @@ import Foundation protocol RuntimeValue { - associatedtype T - var value: T { get set } + associatedtype ValueType + var value: ValueType { get } var builtins: [String: any RuntimeValue] { get set } func bool() -> Bool @@ -21,7 +21,12 @@ struct NumericValue: RuntimeValue { var builtins: [String: any RuntimeValue] = [:] func bool() -> Bool { - self.value as? Int != 0 + if let intValue = self.value as? Int { + return intValue != 0 + } else if let doubleValue = self.value as? Double { + return doubleValue != 0.0 + } + return false } } @@ -35,7 +40,7 @@ struct BooleanValue: RuntimeValue { } struct NullValue: RuntimeValue { - var value: (any RuntimeValue)? + let value: Any? = nil var builtins: [String: any RuntimeValue] = [:] func bool() -> Bool { @@ -44,7 +49,7 @@ struct NullValue: RuntimeValue { } struct UndefinedValue: RuntimeValue { - var value: (any RuntimeValue)? + let value: Any? = nil var builtins: [String: any RuntimeValue] = [:] func bool() -> Bool { @@ -69,11 +74,18 @@ struct ArrayValue: RuntimeValue { } struct TupleValue: RuntimeValue { - var value: ArrayValue + var value: [any RuntimeValue] var builtins: [String: any RuntimeValue] = [:] + init(value: [any RuntimeValue]) { + self.value = value + self.builtins["length"] = FunctionValue(value: { _, _ in + NumericValue(value: value.count) + }) + } + func bool() -> Bool { - self.value.bool() + !self.value.isEmpty } } @@ -85,29 +97,22 @@ struct ObjectValue: RuntimeValue { self.value = value self.builtins = [ "get": FunctionValue(value: { args, _ in - if let key = args[0] as? StringValue { - if let value = value.first(where: { $0.0 == key.value }) { - return value as! (any RuntimeValue) - } else if args.count > 1 { - return args[1] - } else { - return NullValue() - } + guard let key = args[0] as? StringValue else { + throw JinjaError.runtime("Object key must be a string: got \(type(of: args[0]))") + } + if let value = value[key.value] { + return value + } else if args.count > 1 { + return args[1] } else { - throw JinjaError.runtime("Object key must be a string: got \(type(of:args[0]))") + return NullValue() } }), "items": FunctionValue(value: { _, _ in - var items: [ArrayValue] = [] - for (k, v) in value { - items.append( - ArrayValue(value: [ - StringValue(value: k), - v, - ]) - ) + let items = value.map { (key, value) in + ArrayValue(value: [StringValue(value: key), value]) } - return items as! (any RuntimeValue) + return ArrayValue(value: items) }), ] } @@ -146,12 +151,18 @@ struct StringValue: RuntimeValue { }), "title": FunctionValue(value: { _, _ in - StringValue(value: value.capitalized) + StringValue(value: value.titleCase()) }), "length": FunctionValue(value: { _, _ in NumericValue(value: value.count) }), + "rstrip": FunctionValue(value: { _, _ in + StringValue(value: value.replacingOccurrences(of: "\\s+$", with: "", options: .regularExpression)) + }), + "lstrip": FunctionValue(value: { _, _ in + StringValue(value: value.replacingOccurrences(of: "^\\s+", with: "", options: .regularExpression)) + }), ] } @@ -177,17 +188,14 @@ struct Interpreter { let lastEvaluated = try self.evaluate(statement: statement, environment: environment) if !(lastEvaluated is NullValue), !(lastEvaluated is UndefinedValue) { - if let value = lastEvaluated.value as? String { - result += value + if let stringValue = lastEvaluated as? StringValue { + result += stringValue.value + } else if let numericValue = lastEvaluated as? NumericValue { + result += String(describing: numericValue.value) + } else if let booleanValue = lastEvaluated as? BooleanValue { + result += String(booleanValue.value) } else { - switch lastEvaluated.value { - case let value as Int: - result += String(value) - case let value as String: - result += value - default: - throw JinjaError.runtime("Unknown value type:\(type(of: lastEvaluated.value))") - } + throw JinjaError.runtime("Cannot convert to string: \(type(of: lastEvaluated))") } } } @@ -234,44 +242,44 @@ struct Interpreter { } func evaluateFor(node: For, environment: Environment) throws -> any RuntimeValue { + // Scope for the for loop let scope = Environment(parent: environment) - let iterable = try self.evaluate(statement: node.iterable, environment: scope) - var result = "" - if let iterable = iterable as? ArrayValue { - for i in 0 ..< iterable.value.count { - let loop: [String: any RuntimeValue] = [ - "index": NumericValue(value: i + 1), - "index0": NumericValue(value: i), - "revindex": NumericValue(value: iterable.value.count - i), - "revindex0": NumericValue(value: iterable.value.count - i - 1), - "first": BooleanValue(value: i == 0), - "last": BooleanValue(value: i == iterable.value.count - 1), - "length": NumericValue(value: iterable.value.count), - "previtem": i > 0 ? iterable.value[i - 1] : UndefinedValue(), - "nextitem": i < iterable.value.count - 1 ? iterable.value[i + 1] : UndefinedValue(), - ] - - try scope.setVariable(name: "loop", value: ObjectValue(value: loop)) - - let current = iterable.value[i] - - if let identifier = node.loopvar as? Identifier { - try scope.setVariable(name: identifier.value, value: current) - } else { - } + var test: Expression? + var iterable: any RuntimeValue + if let selectExpression = node.iterable as? SelectExpression { + iterable = try self.evaluate(statement: selectExpression.iterable, environment: scope) + test = selectExpression.test + } else { + iterable = try self.evaluate(statement: node.iterable, environment: scope) + } + + guard let arrayIterable = iterable as? ArrayValue else { + throw JinjaError.runtime("Expected iterable type in for loop: got \(type(of: iterable))") + } + + var items: [any RuntimeValue] = [] + var scopeUpdateFunctions: [((Environment) throws -> Void)] = [] + + for i in 0 ..< arrayIterable.value.count { + let loopScope = Environment(parent: scope) - switch node.loopvar { - case let identifier as Identifier: + let current = arrayIterable.value[i] + + var scopeUpdateFunction: ((Environment) throws -> Void) + if let identifier = node.loopvar as? Identifier { + scopeUpdateFunction = { scope in try scope.setVariable(name: identifier.value, value: current) - case let tupleLiteral as TupleLiteral: - if let current = current as? ArrayValue { - if tupleLiteral.value.count != current.value.count { - throw JinjaError.runtime( - "Too \(tupleLiteral.value.count > current.value.count ? "few" : "many") items to unpack" - ) - } + } + } else if let tupleLiteral = node.loopvar as? TupleLiteral { + if let current = current as? ArrayValue { + if tupleLiteral.value.count != current.value.count { + throw JinjaError.runtime( + "Too \(tupleLiteral.value.count > current.value.count ? "few" : "many") items to unpack" + ) + } + scopeUpdateFunction = { scope in for j in 0 ..< tupleLiteral.value.count { if let identifier = tupleLiteral.value[j] as? Identifier { try scope.setVariable(name: identifier.value, value: current.value[j]) @@ -281,18 +289,63 @@ struct Interpreter { ) } } - } else { - throw JinjaError.runtime("Cannot unpack non-iterable type: \(type(of:current))") } - default: - throw JinjaError.syntaxNotSupported(String(describing: node.loopvar)) + } else { + throw JinjaError.runtime("Cannot unpack non-iterable type: \(type(of:current))") + } + } else { + throw JinjaError.syntaxNotSupported(String(describing: node.loopvar)) + } + + if let ifCondition = node.ifCondition { + try scopeUpdateFunction(loopScope) // Update scope before evaluating the condition + + let ifConditionResult = try self.evaluate(statement: ifCondition, environment: loopScope) + if !ifConditionResult.bool() { + continue // Skip to the next iteration if the condition is false } + } else if let test = test { + try scopeUpdateFunction(loopScope) - let evaluated = try self.evaluateBlock(statements: node.body, environment: scope) - result += evaluated.value + let testValue = try self.evaluate(statement: test, environment: loopScope) + if !testValue.bool() { + continue + } } - } else { - throw JinjaError.runtime("Expected iterable type in for loop: got \(type(of:iterable))") + + items.append(current) + scopeUpdateFunctions.append(scopeUpdateFunction) + } + + var result = "" + var noIteration = true + + for i in 0 ..< items.count { + let loop: [String: any RuntimeValue] = [ + "index": NumericValue(value: i + 1), + "index0": NumericValue(value: i), + "revindex": NumericValue(value: items.count - i), + "revindex0": NumericValue(value: items.count - i - 1), + "first": BooleanValue(value: i == 0), + "last": BooleanValue(value: i == items.count - 1), + "length": NumericValue(value: items.count), + "previtem": i > 0 ? items[i - 1] : UndefinedValue(), + "nextitem": i < items.count - 1 ? items[i + 1] : UndefinedValue(), + ] + + try scope.setVariable(name: "loop", value: ObjectValue(value: loop)) + + try scopeUpdateFunctions[i](scope) + + let evaluated = try self.evaluateBlock(statements: node.body, environment: scope) + result += evaluated.value + + noIteration = false + } + + if noIteration { + let defaultEvaluated = try self.evaluateBlock(statements: node.defaultBlock, environment: scope) + result += defaultEvaluated.value } return StringValue(value: result) @@ -309,24 +362,61 @@ struct Interpreter { let right = try self.evaluate(statement: node.right, environment: environment) + // == if node.operation.value == "==" { - switch left.value { - case let value as String: - return BooleanValue(value: value == right.value as! String) - case let value as Int: - return BooleanValue(value: value == right.value as! Int) - case let value as Bool: - return BooleanValue(value: value == right.value as! Bool) - default: - throw JinjaError.runtime( - "Unknown left value type:\(type(of: left.value)), right value type:\(type(of: right.value))" - ) + if let left = left as? StringValue, let right = right as? StringValue { + return BooleanValue(value: left.value == right.value) + } else if let left = left as? NumericValue, let right = right as? NumericValue { + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return BooleanValue(value: leftInt == rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return BooleanValue(value: leftDouble == rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return BooleanValue(value: Double(leftInt) == rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return BooleanValue(value: leftDouble == Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for equality comparison") + } + } else if let left = left as? BooleanValue, let right = right as? BooleanValue { + return BooleanValue(value: left.value == right.value) + } else if left is NullValue, right is NullValue { + return BooleanValue(value: true) + } else if left is UndefinedValue, right is UndefinedValue { + return BooleanValue(value: true) + } else if type(of: left) == type(of: right) { + return BooleanValue(value: false) + } else { + return BooleanValue(value: false) } - } else if node.operation.value == "!=" { - if type(of: left) != type(of: right) { + } + + // != + if node.operation.value == "!=" { + if let left = left as? StringValue, let right = right as? StringValue { + return BooleanValue(value: left.value != right.value) + } else if let left = left as? NumericValue, let right = right as? NumericValue { + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return BooleanValue(value: leftInt != rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return BooleanValue(value: leftDouble != rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return BooleanValue(value: Double(leftInt) != rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return BooleanValue(value: leftDouble != Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for inequality comparison") + } + } else if let left = left as? BooleanValue, let right = right as? BooleanValue { + return BooleanValue(value: left.value != right.value) + } else if left is NullValue, right is NullValue { + return BooleanValue(value: false) + } else if left is UndefinedValue, right is UndefinedValue { + return BooleanValue(value: false) + } else if type(of: left) == type(of: right) { return BooleanValue(value: true) } else { - return BooleanValue(value: left.value as! AnyHashable != right.value as! AnyHashable) + return BooleanValue(value: true) } } @@ -336,92 +426,228 @@ struct Interpreter { throw JinjaError.runtime("Cannot perform operation on null values") } else if let left = left as? NumericValue, let right = right as? NumericValue { switch node.operation.value { - case "+": throw JinjaError.syntaxNotSupported("+") - case "-": throw JinjaError.syntaxNotSupported("-") - case "*": throw JinjaError.syntaxNotSupported("*") - case "/": throw JinjaError.syntaxNotSupported("/") + case "+": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return NumericValue(value: leftInt + rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return NumericValue(value: leftDouble + rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return NumericValue(value: Double(leftInt) + rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return NumericValue(value: leftDouble + Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for addition") + } + case "-": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return NumericValue(value: leftInt - rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return NumericValue(value: leftDouble - rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return NumericValue(value: Double(leftInt) - rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return NumericValue(value: leftDouble - Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for subtraction") + } + case "*": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return NumericValue(value: leftInt * rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return NumericValue(value: leftDouble * rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return NumericValue(value: Double(leftInt) * rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return NumericValue(value: leftDouble * Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for multiplication") + } + case "/": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return NumericValue(value: leftInt / rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return NumericValue(value: leftDouble / rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return NumericValue(value: Double(leftInt) / rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return NumericValue(value: leftDouble / Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for division") + } case "%": - switch left.value { - case is Int: - return NumericValue(value: left.value as! Int % (right.value as! Int)) - default: - throw JinjaError.runtime("Unknown value type:\(type(of: left.value))") + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return NumericValue(value: leftInt % rightInt) + } else { + throw JinjaError.runtime("Unsupported numeric types for modulus") + } + case "<": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return BooleanValue(value: leftInt < rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return BooleanValue(value: leftDouble < rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return BooleanValue(value: Double(leftInt) < rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return BooleanValue(value: leftDouble < Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for less than comparison") + } + case ">": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return BooleanValue(value: leftInt > rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return BooleanValue(value: leftDouble > rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return BooleanValue(value: Double(leftInt) > rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return BooleanValue(value: leftDouble > Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for greater than comparison") + } + case ">=": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return BooleanValue(value: leftInt >= rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return BooleanValue(value: leftDouble >= rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return BooleanValue(value: Double(leftInt) >= rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return BooleanValue(value: leftDouble >= Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for greater than or equal to comparison") + } + case "<=": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return BooleanValue(value: leftInt <= rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return BooleanValue(value: leftDouble <= rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return BooleanValue(value: Double(leftInt) <= rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return BooleanValue(value: leftDouble <= Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for less than or equal to comparison") } - case "<": throw JinjaError.syntaxNotSupported("<") - case ">": throw JinjaError.syntaxNotSupported(">") - case ">=": throw JinjaError.syntaxNotSupported(">=") - case "<=": throw JinjaError.syntaxNotSupported("<=") default: throw JinjaError.runtime("Unknown operation type:\(node.operation.value)") } - } else if left is ArrayValue && right is ArrayValue { + } else if let left = left as? ArrayValue, let right = right as? ArrayValue { switch node.operation.value { - case "+": break + case "+": + return ArrayValue(value: left.value + right.value) default: throw JinjaError.runtime("Unknown operation type:\(node.operation.value)") } - } else if right is ArrayValue { - throw JinjaError.syntaxNotSupported("right is ArrayValue") - } - - if left is StringValue || right is StringValue { - switch node.operation.value { - case "+": - var rightValue = "" - var leftValue = "" - switch right.value { - case let value as String: - rightValue = value - case let value as Int: - rightValue = String(value) - case let value as Bool: - rightValue = String(value) - default: - throw JinjaError.runtime("Unknown right value type:\(type(of: right.value))") + } else if let right = right as? ArrayValue { + let member: Bool + if let left = left as? StringValue { + member = right.value.contains { + if let item = $0 as? StringValue { + return item.value == left.value + } + return false } - - switch left.value { - case let value as String: - leftValue = value - case let value as Int: - leftValue = String(value) - case let value as Bool: - rightValue = String(value) - default: - throw JinjaError.runtime("Unknown left value type:\(type(of: left.value))") + } else if let left = left as? NumericValue { + member = right.value.contains { + if let item = $0 as? NumericValue { + return item.value as! Int == left.value as! Int + } + return false } - - return StringValue(value: leftValue + rightValue) - default: - break + } else if let left = left as? BooleanValue { + member = right.value.contains { + if let item = $0 as? BooleanValue { + return item.value == left.value + } + return false + } + } else { + throw JinjaError.runtime("Unsupported left type for 'in'/'not in' operation with ArrayValue") } - } - if let left = left as? StringValue, let right = right as? StringValue { switch node.operation.value { case "in": - return BooleanValue(value: right.value.contains(left.value)) + return BooleanValue(value: member) case "not in": - return BooleanValue(value: !right.value.contains(left.value)) + return BooleanValue(value: !member) default: throw JinjaError.runtime("Unknown operation type:\(node.operation.value)") } } - if left is StringValue, right is ObjectValue { + if let left = left as? StringValue { switch node.operation.value { + case "+": + let rightValue: String + if let rightString = right as? StringValue { + rightValue = rightString.value + } else if let rightNumeric = right as? NumericValue { + rightValue = String(describing: rightNumeric.value) + } else if let rightBoolean = right as? BooleanValue { + rightValue = String(rightBoolean.value) + } else { + throw JinjaError.runtime("Unsupported right operand type for string concatenation") + } + return StringValue(value: left.value + rightValue) case "in": - if let leftString = (left as? StringValue)?.value, - let rightObject = right as? ObjectValue - { - return BooleanValue(value: rightObject.value.keys.contains(leftString)) + if let right = right as? StringValue { + return BooleanValue(value: right.value.contains(left.value)) + } else if let right = right as? ObjectValue { + return BooleanValue(value: right.value.keys.contains(left.value)) + } else if let right = right as? ArrayValue { + return BooleanValue( + value: right.value.contains { + if let item = $0 as? StringValue { + return item.value == left.value + } + return false + } + ) + } else { + throw JinjaError.runtime("Right operand of 'in' must be a StringValue, ArrayValue, or ObjectValue") } case "not in": - if let leftString = (left as? StringValue)?.value, - let rightObject = right as? ObjectValue - { - return BooleanValue(value: !rightObject.value.keys.contains(leftString)) + if let right = right as? StringValue { + return BooleanValue(value: !right.value.contains(left.value)) + } else if let right = right as? ObjectValue { + return BooleanValue(value: !right.value.keys.contains(left.value)) + } else if let right = right as? ArrayValue { + return BooleanValue( + value: !right.value.contains { + if let item = $0 as? StringValue { + return item.value == left.value + } + return false + } + ) + } else { + throw JinjaError.runtime( + "Right operand of 'not in' must be a StringValue, ArrayValue, or ObjectValue" + ) + } + default: + break + } + } else if let right = right as? StringValue { + if node.operation.value == "+" { + if let leftString = left as? StringValue { + return StringValue(value: leftString.value + right.value) + } else if let leftNumeric = left as? NumericValue { + return StringValue(value: String(describing: leftNumeric.value) + right.value) + } else if let leftBoolean = left as? BooleanValue { + return StringValue(value: String(leftBoolean.value) + right.value) + } else { + throw JinjaError.runtime("Unsupported left operand type for string concatenation") } + } + } + + if let left = left as? StringValue, let right = right as? ObjectValue { + switch node.operation.value { + case "in": + return BooleanValue(value: right.value.keys.contains(left.value)) + case "not in": + return BooleanValue(value: !right.value.keys.contains(left.value)) default: throw JinjaError.runtime( "Unsupported operation '\(node.operation.value)' between StringValue and ObjectValue" @@ -463,19 +689,19 @@ struct Interpreter { return ArrayValue( value: slice( object.value, - start: start.value as? Int, - stop: stop.value as? Int, - step: step.value as? Int + start: (start as? NumericValue)?.value as? Int, + stop: (stop as? NumericValue)?.value as? Int, + step: (step as? NumericValue)?.value as? Int ) ) } else if let object = object as? StringValue { return StringValue( value: slice( - Array(arrayLiteral: object.value), - start: start.value as? Int, - stop: stop.value as? Int, - step: step.value as? Int - ).joined() + Array(object.value), + start: (start as? NumericValue)?.value as? Int, + stop: (stop as? NumericValue)?.value as? Int, + step: (step as? NumericValue)?.value as? Int + ).map { String($0) }.joined() ) } @@ -503,29 +729,36 @@ struct Interpreter { } else { throw JinjaError.runtime("Cannot access property with non-string: got \(type(of:property))") } - } else if object is ArrayValue || object is StringValue { + } else if let object = object as? ArrayValue { if let property = property as? NumericValue { - if let object = object as? ArrayValue { - let index = property.value as! Int - if index >= 0 { - value = object.value[index] - } else { - value = object.value[object.value.count + index] - } - } else if let object = object as? StringValue { - let index = object.value.index(object.value.startIndex, offsetBy: property.value as! Int) - value = StringValue(value: String(object.value[index])) + let index = property.value as! Int + if index >= 0 { + value = object.value[index] + } else { + value = object.value[object.value.count + index] } } else if let property = property as? StringValue { value = object.builtins[property.value] } else { throw JinjaError.runtime( - "Cannot access property with non-string/non-number: got \(type(of:property))" + "Cannot access property with non-string/non-number: got \(type(of: property))" + ) + } + } else if let object = object as? StringValue { + if let property = property as? NumericValue { + let index = property.value as! Int + let strIndex = object.value.index(object.value.startIndex, offsetBy: index) + value = StringValue(value: String(object.value[strIndex])) + } else if let property = property as? StringValue { + value = object.builtins[property.value] + } else { + throw JinjaError.runtime( + "Cannot access property with non-string/non-number: got \(type(of: property))" ) } } else { if let property = property as? StringValue { - value = object.builtins[property.value]! + value = object.builtins[property.value] } else { throw JinjaError.runtime("Cannot access property with non-string: got \(type(of:property))") } @@ -561,7 +794,7 @@ struct Interpreter { } } - if kwargs.count > 0 { + if !kwargs.isEmpty { args.append(ObjectValue(value: kwargs)) } @@ -575,9 +808,13 @@ struct Interpreter { } func evaluateFilterExpression(node: FilterExpression, environment: Environment) throws -> any RuntimeValue { - let operand = try evaluate(statement: node.operand, environment: environment) + let operand = try self.evaluate(statement: node.operand, environment: environment) if let identifier = node.filter as? Identifier { + if identifier.value == "tojson" { + return try StringValue(value: toJSON(operand)) + } + if let arrayValue = operand as? ArrayValue { switch identifier.value { case "list": @@ -591,7 +828,32 @@ struct Interpreter { case "reverse": return ArrayValue(value: arrayValue.value.reversed()) case "sort": - throw JinjaError.todo("TODO: ArrayValue filter sort") + return ArrayValue( + value: try arrayValue.value.sorted { + // No need to cast to AnyComparable here + if let a = $0 as? NumericValue, let b = $1 as? NumericValue { + if let aInt = a.value as? Int, let bInt = b.value as? Int { + return aInt < bInt + } else if let aDouble = a.value as? Double, let bDouble = b.value as? Double { + return aDouble < bDouble + } else if let aInt = a.value as? Int, let bDouble = b.value as? Double { + return Double(aInt) < bDouble + } else if let aDouble = a.value as? Double, let bInt = b.value as? Int { + return aDouble < Double(bInt) + } else { + throw JinjaError.runtime("Unsupported numeric types for comparison") + } + } else if let a = $0 as? StringValue, let b = $1 as? StringValue { + return a.value < b.value + } else { + throw JinjaError.runtime( + "Cannot compare values of different types or non-comparable types" + ) + } + } + ) + case "map": + throw JinjaError.todo("TODO: ArrayValue filter map") default: throw JinjaError.runtime("Unknown ArrayValue filter: \(identifier.value)") } @@ -604,34 +866,38 @@ struct Interpreter { case "lower": return StringValue(value: stringValue.value.lowercased()) case "title": - return StringValue(value: stringValue.value.capitalized) + return StringValue(value: stringValue.value.titleCase()) case "capitalize": - return StringValue(value: stringValue.value.capitalized) + return StringValue(value: stringValue.value.prefix(1).uppercased() + stringValue.value.dropFirst()) case "trim": return StringValue(value: stringValue.value.trimmingCharacters(in: .whitespacesAndNewlines)) + case "indent": + return StringValue(value: stringValue.value.indent(4)) + case "string": + return stringValue default: throw JinjaError.runtime("Unknown StringValue filter: \(identifier.value)") } } else if let numericValue = operand as? NumericValue { switch identifier.value { case "abs": - return NumericValue(value: abs(numericValue.value as! Int32)) + if let intValue = numericValue.value as? Int { + return NumericValue(value: abs(intValue)) + } else if let doubleValue = numericValue.value as? Double { + return NumericValue(value: abs(doubleValue)) + } else { + throw JinjaError.runtime("Unsupported numeric type for abs filter") + } default: throw JinjaError.runtime("Unknown NumericValue filter: \(identifier.value)") } } else if let objectValue = operand as? ObjectValue { switch identifier.value { case "items": - var items: [ArrayValue] = [] - for (k, v) in objectValue.value { - items.append( - ArrayValue(value: [ - StringValue(value: k), - v, - ]) - ) + let items: [ArrayValue] = objectValue.value.map { (key, value) in + return ArrayValue(value: [StringValue(value: key), value]) } - return items as! (any RuntimeValue) + return ArrayValue(value: items) case "length": return NumericValue(value: objectValue.value.count) default: @@ -639,7 +905,133 @@ struct Interpreter { } } - throw JinjaError.runtime("Cannot apply filter \(operand.value) to type: \(type(of:operand))") + throw JinjaError.runtime("Cannot apply filter \(identifier.value) to type: \(type(of: operand))") + } else if let callExpression = node.filter as? CallExpression { + if let identifier = callExpression.callee as? Identifier { + let filterName = identifier.value + + if filterName == "tojson" { + let args = try self.evaluateArguments(args: callExpression.args, environment: environment) + let indent = args.1["indent"] ?? NullValue() + + if !(indent is NumericValue || indent is NullValue) { + throw JinjaError.runtime("If set, indent must be a number") + } + + return try StringValue(value: toJSON(operand, indent: (indent as? NumericValue)?.value as? Int)) + } + + if let arrayValue = operand as? ArrayValue { + switch filterName { + case "selectattr", "rejectattr": + let select = filterName == "selectattr" + + if arrayValue.value.contains(where: { !($0 is ObjectValue) }) { + throw JinjaError.runtime("`\(filterName)` can only be applied to array of objects") + } + + if callExpression.args.contains(where: { !($0 is StringLiteral) }) { + throw JinjaError.runtime("arguments of `\(filterName)` must be strings") + } + + let args = try callExpression.args.map { arg -> StringValue in + let evaluatedArg = try self.evaluate(statement: arg, environment: environment) + guard let stringValue = evaluatedArg as? StringValue else { + throw JinjaError.runtime("Arguments of `\(filterName)` must be strings") + } + return stringValue + } + + let attr = args[0] + let testName = args.count > 1 ? args[1] : nil + let value = args.count > 2 ? args[2] : nil + + var testFunction: ((any RuntimeValue, StringValue?) throws -> Bool) + + if let testName = testName { + guard let test = environment.tests[testName.value] else { + throw JinjaError.runtime("Unknown test: \(testName.value)") + } + testFunction = { a, b in + try test(a, b ?? UndefinedValue()) + } + } else { + testFunction = { a, _ in + a.bool() + } + } + + let filtered = (arrayValue.value as! [ObjectValue]).filter { item in + let a = item.value[attr.value] + let result = a != nil ? try! testFunction(a!, value) : false + return select ? result : !result + } + + return ArrayValue(value: filtered) + case "map": + let evaluatedArgs = try self.evaluateArguments( + args: callExpression.args, + environment: environment + ) + let kwargs = evaluatedArgs.1 + + if let attribute = kwargs["attribute"] as? StringValue { + let defaultValue = kwargs["default"] + + let mapped = try arrayValue.value.map { item -> Any in + guard let objectValue = item as? ObjectValue else { + throw JinjaError.runtime("Items in map must be objects") + } + return objectValue.value[attribute.value] ?? defaultValue ?? UndefinedValue() + } + + return ArrayValue(value: mapped.map { $0 as! (any RuntimeValue) }) + } else { + // TODO: Implement map filter without attribute argument + // This will likely involve applying a filter function to each element. + throw JinjaError.runtime("`map` filter without `attribute` is not yet supported.") + } + default: + throw JinjaError.runtime("Unknown ArrayValue filter: \(filterName)") + } + } else if let stringValue = operand as? StringValue { + switch filterName { + case "indent": + let args = try self.evaluateArguments(args: callExpression.args, environment: environment) + let positionalArgs = args.0 + let kwargs = args.1 + + let width = positionalArgs.first ?? kwargs["width"] ?? NumericValue(value: 4) + + if !(width is NumericValue) { + throw JinjaError.runtime("width must be a number") + } + + let first = + positionalArgs.count > 1 ? positionalArgs[1] : kwargs["first"] ?? BooleanValue(value: false) + let blank = + positionalArgs.count > 2 ? positionalArgs[2] : kwargs["blank"] ?? BooleanValue(value: false) + + guard let widthInt = (width as? NumericValue)?.value as? Int else { + throw JinjaError.runtime("width must be an integer") + } + + return StringValue( + value: stringValue.value.indent( + widthInt, + first: first.bool(), + blank: blank.bool() + ) + ) + default: + throw JinjaError.runtime("Unknown StringValue filter: \(filterName)") + } + } else { + throw JinjaError.runtime("Cannot apply filter '\(filterName)' to type: \(type(of: operand))") + } + } else { + throw JinjaError.runtime("Unknown filter: \(callExpression.callee)") + } } throw JinjaError.runtime("Unknown filter: \(node.filter)") @@ -656,6 +1048,70 @@ struct Interpreter { } } + func evaluateMacro(node: Macro, environment: Environment) throws -> NullValue { + try environment.setVariable( + name: node.name.value, + value: FunctionValue(value: { args, scope in + let macroScope = Environment(parent: scope) + + var args = args + var kwargs: [String: any RuntimeValue] = [:] + + if let lastArg = args.last, let objectValue = lastArg as? ObjectValue { + kwargs = objectValue.value + args.removeLast() + } + + for i in 0 ..< node.args.count { + let nodeArg = node.args[i] + let passedArg = args.count > i ? args[i] : nil + + if let identifier = nodeArg as? Identifier { + if passedArg == nil { + throw JinjaError.runtime("Missing positional argument: \(identifier.value)") + } + try macroScope.setVariable(name: identifier.value, value: passedArg!) + } else if let kwarg = nodeArg as? KeywordArgumentExpression { + let value = + try passedArg ?? kwargs[kwarg.key.value] + ?? (try self.evaluate(statement: kwarg.value, environment: macroScope)) + try macroScope.setVariable(name: kwarg.key.value, value: value) + } else { + throw JinjaError.runtime("Unknown argument type: \(type(of: nodeArg))") + } + } + + return try self.evaluateBlock(statements: node.body, environment: macroScope) + }) + ) + + return NullValue() + } + + func evaluateArguments( + args: [Expression], + environment: Environment + ) throws -> ([any RuntimeValue], [String: any RuntimeValue]) { + var positionalArguments: [any RuntimeValue] = [] + var keywordArguments: [String: any RuntimeValue] = [:] + + for argument in args { + if let keywordArgument = argument as? KeywordArgumentExpression { + keywordArguments[keywordArgument.key.value] = try self.evaluate( + statement: keywordArgument.value, + environment: environment + ) + } else { + if !keywordArguments.isEmpty { + throw JinjaError.runtime("Positional arguments must come before keyword arguments") + } + positionalArguments.append(try self.evaluate(statement: argument, environment: environment)) + } + } + + return (positionalArguments, keywordArguments) + } + func evaluate(statement: Statement?, environment: Environment) throws -> any RuntimeValue { if let statement { switch statement { @@ -687,6 +1143,22 @@ struct Interpreter { return try self.evaluateFilterExpression(node: statement, environment: environment) case let statement as TestExpression: return try self.evaluateTestExpression(node: statement, environment: environment) + case let statement as ArrayLiteral: + return ArrayValue( + value: try statement.value.map { try self.evaluate(statement: $0, environment: environment) } + ) + case let statement as TupleLiteral: + return TupleValue( + value: try statement.value.map { try self.evaluate(statement: $0, environment: environment) } + ) + case let statement as ObjectLiteral: + var mapping: [String: any RuntimeValue] = [:] + for (key, value) in statement.value { + mapping[key] = try self.evaluate(statement: value, environment: environment) + } + return ObjectValue(value: mapping) + case let statement as Macro: + return try self.evaluateMacro(node: statement, environment: environment) case is NullLiteral: return NullValue() default: diff --git a/Sources/Utilities.swift b/Sources/Utilities.swift index c01870b..adef924 100644 --- a/Sources/Utilities.swift +++ b/Sources/Utilities.swift @@ -38,3 +38,89 @@ func slice(_ array: [T], start: Int? = nil, stop: Int? = nil, step: Int? = 1) return slicedArray } + +func toJSON(_ input: any RuntimeValue, indent: Int? = nil, depth: Int = 0) throws -> String { + let currentDepth = depth + + switch input { + case is NullValue, is UndefinedValue: + return "null" + + case let value as NumericValue: + return String(describing: value.value) + + case let value as StringValue: + return "\"\(value.value)\"" // Directly wrap string in quotes + + case let value as BooleanValue: + return value.value ? "true" : "false" + + case let arr as ArrayValue: + let indentValue = indent != nil ? String(repeating: " ", count: indent!) : "" + let basePadding = "\n" + String(repeating: indentValue, count: currentDepth) + let childrenPadding = basePadding + indentValue // Depth + 1 + + let core = try arr.value.map { try toJSON($0, indent: indent, depth: currentDepth + 1) } + + if indent != nil { + return "[\(childrenPadding)\(core.joined(separator: ",\(childrenPadding)"))\(basePadding)]" + } else { + return "[\(core.joined(separator: ", "))]" + } + + case let obj as ObjectValue: + let indentValue = indent != nil ? String(repeating: " ", count: indent!) : "" + let basePadding = "\n" + String(repeating: indentValue, count: currentDepth) + let childrenPadding = basePadding + indentValue // Depth + 1 + + let core = try obj.value.map { key, value in + let v = "\"\(key)\": \(try toJSON(value, indent: indent, depth: currentDepth + 1))" + return indent != nil ? "\(childrenPadding)\(v)" : v + } + + if indent != nil { + return "{\(core.joined(separator: ","))\(basePadding)}" + } else { + return "{\(core.joined(separator: ", "))}" + } + + default: + throw JinjaError.runtime("Cannot convert to JSON: \(type(of: input))") + } +} + +// Helper function to convert values to JSON strings +private func jsonString(_ value: Any) throws -> String { + let data = try JSONSerialization.data(withJSONObject: value) + guard let string = String(data: data, encoding: .utf8) else { + throw JinjaError.runtime("Failed to convert value to JSON string") + } + return string +} + +extension String { + func titleCase() -> String { + return self.components(separatedBy: .whitespacesAndNewlines) + .map { word in + guard let firstChar = word.first else { return "" } + return String(firstChar).uppercased() + word.dropFirst() + } + .joined(separator: " ") + } + + func indent(_ width: Int, first: Bool = false, blank: Bool = false) -> String { + let indentString = String(repeating: " ", count: width) + return self.components(separatedBy: .newlines) + .enumerated() + .map { (index, line) in + if line.isEmpty && !blank { + return line + } + if index == 0 && !first { + return line + } + return indentString + line + } + .joined(separator: "\n") + } +} diff --git a/Tests/ChatTemplateTests.swift b/Tests/ChatTemplateTests.swift index 4b9ab6b..6d1f063 100644 --- a/Tests/ChatTemplateTests.swift +++ b/Tests/ChatTemplateTests.swift @@ -9,6 +9,11 @@ import XCTest @testable import Jinja +let llama3_2visionChatTemplate = + "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- Find out if there are any images #}\n{% set image_ns = namespace(has_images=false) %} \n{%- for message in messages %}\n {%- for content in message['content'] %}\n {%- if content['type'] == 'image' %}\n {%- set image_ns.has_images = true %}\n {%- endif %}\n {%- endfor %}\n{%- endfor %}\n\n{#- Error out if there are images and system message #}\n{%- if image_ns.has_images and not system_message == \"\" %}\n {{- raise_exception(\"Prompting with images is incompatible with system messages.\") }}\n{%- endif %}\n\n{#- System message if there are no images #}\n{%- if not image_ns.has_images %}\n {{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n {%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n {%- endif %}\n {{- \"Cutting Knowledge Date: December 2023\\n\" }}\n {{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n {%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {%- endif %}\n {{- system_message }}\n {{- \"<|eot_id|>\" }}\n{%- endif %}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n' }}\n {%- if message['content'] is string %}\n {{- message['content'] }}\n {%- else %}\n {%- for content in message['content'] %}\n {%- if content['type'] == 'image' %}\n {{- '<|image|>' }}\n {%- elif content['type'] == 'text' %}\n {{- content['text'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {{- \"<|eot_id|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n" +let qwen2VLChatTemplate = + "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}" + let messages: [[String: String]] = [ [ "role": "user", @@ -34,6 +39,7 @@ let messagesWithSystem: [[String: String]] = final class ChatTemplateTests: XCTestCase { struct Test { + let name: String let chatTemplate: String let data: [String: Any] let target: String @@ -41,6 +47,7 @@ final class ChatTemplateTests: XCTestCase { let defaultTemplates: [Test] = [ Test( + name: "Generic chat template with messages", chatTemplate: "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", data: [ @@ -52,6 +59,7 @@ final class ChatTemplateTests: XCTestCase { ), // facebook/blenderbot-400M-distill Test( + name: "facebook/blenderbot-400M-distill", chatTemplate: "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}", data: [ @@ -63,6 +71,7 @@ final class ChatTemplateTests: XCTestCase { ), // facebook/blenderbot_small-90M Test( + name: "facebook/blenderbot_small-90M", chatTemplate: "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}", data: [ @@ -74,6 +83,7 @@ final class ChatTemplateTests: XCTestCase { ), // bigscience/bloom Test( + name: "bigscience/bloom", chatTemplate: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", data: [ "messages": messages, @@ -84,6 +94,7 @@ final class ChatTemplateTests: XCTestCase { ), // EleutherAI/gpt-neox-20b Test( + name: "EleutherAI/gpt-neox-20b", chatTemplate: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", data: [ "messages": messages, @@ -92,8 +103,9 @@ final class ChatTemplateTests: XCTestCase { target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>" ), - // gpt2 + // GPT-2 Test( + name: "GPT-2", chatTemplate: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", data: [ "messages": messages, @@ -104,6 +116,7 @@ final class ChatTemplateTests: XCTestCase { ), // hf-internal-testing/llama-tokenizer Test( + name: "hf-internal-testing/llama-tokenizer 1", chatTemplate: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}", data: [ @@ -117,6 +130,7 @@ final class ChatTemplateTests: XCTestCase { ), // hf-internal-testing/llama-tokenizer Test( + name: "hf-internal-testing/llama-tokenizer 2", chatTemplate: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}", data: [ @@ -130,6 +144,7 @@ final class ChatTemplateTests: XCTestCase { ), // hf-internal-testing/llama-tokenizer Test( + name: "hf-internal-testing/llama-tokenizer 3", chatTemplate: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}", data: [ @@ -156,6 +171,7 @@ final class ChatTemplateTests: XCTestCase { ), // openai/whisper-large-v3 Test( + name: "openai/whisper-large-v3", chatTemplate: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", data: [ "messages": messages, @@ -166,6 +182,7 @@ final class ChatTemplateTests: XCTestCase { ), // Qwen/Qwen1.5-1.8B-Chat Test( + name: "Qwen/Qwen1.5-1.8B-Chat", chatTemplate: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}", data: [ @@ -177,6 +194,7 @@ final class ChatTemplateTests: XCTestCase { ), // Qwen/Qwen1.5-1.8B-Chat Test( + name: "Qwen/Qwen1.5-1.8B-Chat 2", chatTemplate: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}", data: [ @@ -188,6 +206,7 @@ final class ChatTemplateTests: XCTestCase { ), // Qwen/Qwen1.5-1.8B-Chat Test( + name: "Qwen/Qwen1.5-1.8B-Chat 3", chatTemplate: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}", data: [ @@ -198,6 +217,7 @@ final class ChatTemplateTests: XCTestCase { ), // THUDM/chatglm3-6b Test( + name: "THUDM/chatglm3-6b", chatTemplate: "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}", data: [ @@ -208,6 +228,7 @@ final class ChatTemplateTests: XCTestCase { ), // google/gemma-2b-it Test( + name: "google/gemma-2b-it", chatTemplate: "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}", data: [ @@ -218,6 +239,7 @@ final class ChatTemplateTests: XCTestCase { ), // Qwen/Qwen2.5-0.5B-Instruct Test( + name: "Qwen/Qwen2.5-0.5B-Instruct", chatTemplate: "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n", data: [ @@ -226,13 +248,254 @@ final class ChatTemplateTests: XCTestCase { target: "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI\'m doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI\'d like to show off how chat templating works!<|im_end|>\n" ), + // Llama-3.2-11B-Vision-Instruct: text chat only + Test( + name: "Llama-3.2-11B-Vision-Instruct: text chat only", + chatTemplate: llama3_2visionChatTemplate, + data: [ + "messages": [ + [ + "role": "user", + "content": [ + [ + "type": "text", + "text": "Hello, how are you?", + ] as [String: Any] + ] as [[String: Any]], + ] as [String: Any], + [ + "role": "assistant", + "content": [ + [ + "type": "text", + "text": "I'm doing great. How can I help you today?", + ] as [String: Any] + ] as [[String: Any]], + ] as [String: Any], + [ + "role": "user", + "content": [ + [ + "type": "text", + "text": "I'd like to show off how chat templating works!", + ] as [String: Any] + ] as [[String: Any]], + ] as [String: Any], + ] as [[String: Any]] as Any, + "bos_token": "" as Any, + "date_string": "26 Jul 2024" as Any, + "tools_in_user_message": true as Any, + "system_message": "You are a helpful assistant." as Any, + "add_generation_prompt": true as Any, + ], + target: + "\n<|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello, how are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI'm doing great. How can I help you today?<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nI'd like to show off how chat templating works!<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ), + // Llama-3.2-11B-Vision-Instruct: with images + Test( + name: "Llama-3.2-11B-Vision-Instruct: with images", + chatTemplate: llama3_2visionChatTemplate, + data: [ + "messages": [ + [ + "role": "user", + "content": [ + [ + "type": "text", + "text": "What's in this image?", + ] as [String: Any], + [ + "type": "image", + "image": "base64_encoded_image_data", + ] as [String: Any], + ] as [[String: Any]], + ] as [String: Any] + ] as [[String: Any]] as Any, + "bos_token": "" as Any, + "add_generation_prompt": true as Any, + ], + target: + "\n<|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat's in this image?<|image|><|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ), + // Qwen2-VL text only + Test( + name: "Qwen2-VL-7B-Instruct: text only", + chatTemplate: qwen2VLChatTemplate, + data: [ + "messages": messages, + "add_generation_prompt": true, + ], + target: """ + <|im_start|>system + You are a helpful assistant.<|im_end|> + <|im_start|>user + Hello, how are you?<|im_end|> + <|im_start|>assistant + I'm doing great. How can I help you today?<|im_end|> + <|im_start|>user + I'd like to show off how chat templating works!<|im_end|> + <|im_start|>assistant + + """ + ), + // Qwen2-VL with images + Test( + name: "Qwen2-VL-7B-Instruct: with images", + chatTemplate: qwen2VLChatTemplate, + data: [ + "messages": [ + [ + "role": "user", + "content": [ + [ + "type": "text", + "text": "What's in this image?", + ] as [String: String], + [ + "type": "image", + "image_url": "example.jpg", + ] as [String: String], + ] as [[String: String]], + ] as [String: Any] + ] as [[String: Any]], + "add_generation_prompt": true, + "add_vision_id": true, + ], + target: """ + <|im_start|>system + You are a helpful assistant.<|im_end|> + <|im_start|>user + What's in this image?Picture 0: <|vision_start|><|image_pad|><|vision_end|><|im_end|> + <|im_start|>assistant + + """ + ), + // Qwen2-VL with video + Test( + name: "Qwen2-VL-7B-Instruct: with video", + chatTemplate: qwen2VLChatTemplate, + data: [ + "messages": [ + [ + "role": "user", + "content": [ + [ + "type": "text", + "text": "What's happening in this video?", + ] as [String: String], + [ + "type": "video", + "video_url": "example.mp4", + ] as [String: String], + ] as [[String: String]], + ] as [String: Any] + ] as [[String: Any]], + "add_generation_prompt": true, + "add_vision_id": true, + ], + target: """ + <|im_start|>system + You are a helpful assistant.<|im_end|> + <|im_start|>user + What's happening in this video?Video 0: <|vision_start|><|video_pad|><|vision_end|><|im_end|> + <|im_start|>assistant + + """ + ), ] func testDefaultTemplates() throws { for test in defaultTemplates { + print("Testing \(test.name)") let template = try Template(test.chatTemplate) let result = try template.render(test.data) + print(result) XCTAssertEqual(result.debugDescription, test.target.debugDescription) } } + + // TODO: Get testLlama32ToolCalls working + + // func testLlama32ToolCalls() throws { + // let tools = [ + // [ + // "name": "get_current_weather", + // "description": "Get the current weather in a given location", + // "parameters": [ + // "type": "object", + // "properties": [ + // "location": [ + // "type": "string", + // "description": "The city and state, e.g. San Francisco, CA" + // ], + // "unit": [ + // "type": "string", + // "enum": ["celsius", "fahrenheit"] + // ] + // ], + // "required": ["location"] + // ] + // ] + // ] + // + // let messages: [[String: Any]] = [ + // [ + // "role": "user", + // "content": "What's the weather like in San Francisco?" + // ], + // [ + // "role": "assistant", + // "tool_calls": [ + // [ + // "function": [ + // "name": "get_current_weather", + // "arguments": "{\"location\": \"San Francisco, CA\", \"unit\": \"celsius\"}" + // ] + // ] + // ] + // ], + // [ + // "role": "tool", + // "content": "{\"temperature\": 22, \"unit\": \"celsius\", \"description\": \"Sunny\"}" + // ], + // [ + // "role": "assistant", + // "content": "The weather in San Francisco is sunny with a temperature of 22°C." + // ] + // ] + // + // let template = try Template(llama3_2visionChatTemplate) + // let result = try template.render([ + // "messages": messages, + // "tools": tools, + // "bos_token": "", + // "date_string": "26 Jul 2024", + // "add_generation_prompt": true + // ]) + // + // print(result) // Debugging for comparison with expected + // + // // TODO: Replace with printed result if it works + // let expected = """ + // + // <|start_header_id|>system<|end_header_id|> + // + // Environment: ipython + // Cutting Knowledge Date: December 2023 + // Today Date: 26 Jul 2024 + // + // <|eot_id|><|start_header_id|>user<|end_header_id|> + // + // What's the weather like in San Francisco?<|eot_id|><|start_header_id|>assistant<|end_header_id|> + // + // {"name": "get_current_weather", "parameters": {"location": "San Francisco, CA", "unit": "celsius"}}<|eot_id|><|start_header_id|>ipython<|end_header_id|> + // + // {"temperature": 22, "unit": "celsius", "description": "Sunny"}<|eot_id|><|start_header_id|>assistant<|end_header_id|> + // + // The weather in San Francisco is sunny with a temperature of 22°C.<|eot_id|><|start_header_id|>assistant<|end_header_id|> + // + // """ + // + // XCTAssertEqual(result, expected) + // } } diff --git a/Tests/InterpreterTests.swift b/Tests/InterpreterTests.swift index d402f84..631d2e6 100644 --- a/Tests/InterpreterTests.swift +++ b/Tests/InterpreterTests.swift @@ -141,17 +141,18 @@ final class InterpreterTests: XCTestCase { for test in tests { let env = Environment() try env.set(name: "True", value: true) - for (key, value) in test.data { try env.set(name: key, value: value) } - let tokens = try tokenize(test.template, options: test.options) let parsed = try parse(tokens: tokens) let interpreter = Interpreter(env: env) - let result = try interpreter.run(program: parsed).value as! String - - XCTAssertEqual(result.debugDescription, test.target.debugDescription) + let result = try interpreter.run(program: parsed) + if let stringResult = result as? StringValue { + XCTAssertEqual(stringResult.value.debugDescription, test.target.debugDescription) + } else { + XCTFail("Expected a StringValue, but got \(type(of: result))") + } } } }