diff --git a/.gitignore b/.gitignore index 1d4c618..095f101 100644 --- a/.gitignore +++ b/.gitignore @@ -94,4 +94,8 @@ iOSInjectionProject/ /Packages .netrc .idea -.swiftpm \ No newline at end of file +.swiftpm + +# Specific to this package + +*.code-workspace diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c12932f..28c3445 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/slessans/pre-commit-swift-format - rev: "" + rev: "fd627de92bdf84a75c924ed95691336d14e94cf1" hooks: - id: swift-format args: ["--configuration", ".swift-format"] diff --git a/.swift-format b/.swift-format index 855c4dc..fd02f70 100644 --- a/.swift-format +++ b/.swift-format @@ -8,5 +8,8 @@ "respectsExistingLineBreaks": true, "lineBreakBeforeEachArgument": true, "multiElementCollectionTrailingCommas": true, - "spacesAroundRangeFormationOperators": true + "spacesAroundRangeFormationOperators": true, + "rules": { + "AlwaysUseLowerCamelCase": false + } } diff --git a/Package.resolved b/Package.resolved new file mode 100644 index 0000000..cb9509a --- /dev/null +++ b/Package.resolved @@ -0,0 +1,14 @@ +{ + "pins" : [ + { + "identity" : "swift-collections", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-collections.git", + "state" : { + "revision" : "671108c96644956dddcd89dd59c203dcdb36cec7", + "version" : "1.1.4" + } + } + ], + "version" : 2 +} diff --git a/Package.swift b/Package.swift index e7074aa..80ada54 100644 --- a/Package.swift +++ b/Package.swift @@ -13,17 +13,25 @@ let package = Package( targets: ["Jinja"] ) ], + dependencies: [ + .package(url: "https://github.com/apple/swift-collections.git", from: "1.1.4") + ], targets: [ // Targets are the basic building blocks of a package, defining a module or a test suite. // Targets can depend on other targets in this package and products from dependencies. .target( name: "Jinja", + dependencies: [ + .product(name: "OrderedCollections", package: "swift-collections") + ], path: "Sources", swiftSettings: [.enableUpcomingFeature("BareSlashRegexLiterals")] ), .testTarget( name: "JinjaTests", - dependencies: ["Jinja"], + dependencies: [ + "Jinja" + ], path: "Tests", swiftSettings: [.enableUpcomingFeature("BareSlashRegexLiterals")] ), diff --git a/Sources/Ast.swift b/Sources/Ast.swift index 7460284..5d04265 100644 --- a/Sources/Ast.swift +++ b/Sources/Ast.swift @@ -6,6 +6,7 @@ // import Foundation +import OrderedCollections protocol Statement {} @@ -41,7 +42,7 @@ struct TupleLiteral: Literal { } struct ObjectLiteral: Literal { - var value: [(Expression, Expression)] + var value: OrderedDictionary } struct Set: Statement { @@ -49,7 +50,7 @@ struct Set: Statement { var value: Expression } -struct If: Statement { +struct If: Statement, Expression { var test: Expression var body: [Statement] var alternate: [Statement] @@ -59,14 +60,14 @@ struct Identifier: Expression { var value: String } -protocol Loopvar {} -extension Identifier: Loopvar {} -extension TupleLiteral: Loopvar {} +typealias Loopvar = Expression struct For: Statement { var loopvar: Loopvar var iterable: Expression var body: [Statement] + var defaultBlock: [Statement] + var test: Expression? } struct MemberExpression: Expression { @@ -92,7 +93,11 @@ extension CallExpression: Filter {} struct FilterExpression: Expression { var operand: Expression - var filter: Filter + var filter: Identifier + var args: [Expression] + var kwargs: [KeywordArgumentExpression] + var dyn_args: Expression? + var dyn_kwargs: Expression? } struct TestExpression: Expression { @@ -124,3 +129,23 @@ struct KeywordArgumentExpression: Expression { struct NullLiteral: Literal { var value: Any? = nil } + +struct SelectExpression: Expression { + var iterable: Expression + var test: Expression +} + +struct Macro: Statement { + var name: Identifier + var args: [Expression] + var body: [Statement] +} + +struct KeywordArgumentsValue: RuntimeValue { + var value: [String: any RuntimeValue] + var builtins: [String: any RuntimeValue] = [:] + + func bool() -> Bool { + !value.isEmpty + } +} diff --git a/Sources/Environment.swift b/Sources/Environment.swift index c845068..a12477d 100644 --- a/Sources/Environment.swift +++ b/Sources/Environment.swift @@ -6,49 +6,75 @@ // import Foundation +import OrderedCollections class Environment { var parent: Environment? var variables: [String: any RuntimeValue] = [ "namespace": FunctionValue(value: { args, _ in - if args.count == 0 { + if args.isEmpty { return ObjectValue(value: [:]) } - - if args.count != 1 || !(args[0] is ObjectValue) { + guard args.count == 1, let objectArg = args[0] as? ObjectValue else { throw JinjaError.runtime("`namespace` expects either zero arguments or a single object argument") } - - return args[0] + return objectArg }) ] - var tests: [String: (any RuntimeValue...) throws -> Bool] = [ - "boolean": { - args in - args[0] is BooleanValue - }, - - "callable": { - args in - args[0] is FunctionValue - }, - - "odd": { - args in - if let arg = args.first as? NumericValue { - return arg.value as! Int % 2 != 0 + lazy var tests: [String: (any RuntimeValue...) throws -> Bool] = [ + "odd": { args in + if let arg = args.first as? NumericValue, let intValue = arg.value as? Int { + return intValue % 2 != 0 } else { - throw JinjaError.runtime("Cannot apply test 'odd' to type: \(type(of:args.first))") + throw JinjaError.runtime( + "Cannot apply test 'odd' to type: \(type(of: args.first)) with value \(String(describing: args.first?.value))" + ) } }, "even": { args in - if let arg = args.first as? NumericValue { - return arg.value as! Int % 2 == 0 + if let arg = args.first as? NumericValue, let intValue = arg.value as? Int { + return intValue % 2 == 0 } else { - throw JinjaError.runtime("Cannot apply test 'even' to type: \(type(of:args.first))") + throw JinjaError.runtime( + "Cannot apply test 'even' to type: \(type(of: args.first)) with value \(String(describing: args.first?.value))" + ) + } + }, + "divisibleby": { args in + guard let value = args[0] as? NumericValue, + let num = args[1] as? NumericValue, + let intValue = value.value as? Int, + let intNum = num.value as? Int + else { + throw JinjaError.runtime("divisibleby test requires two integers") + } + return intValue % intNum == 0 + }, + "defined": { args in + return !(args[0] is UndefinedValue) + }, + "undefined": { args in + return args[0] is UndefinedValue + }, + "filter": { [weak self] (args: any RuntimeValue...) throws -> Bool in + guard let name = args[0] as? StringValue else { + throw JinjaError.runtime("filter test requires a string") + } + return self?.filters.keys.contains(name.value) ?? false + }, + "test": { [weak self] (args: any RuntimeValue...) throws -> Bool in + guard let name = args[0] as? StringValue else { + throw JinjaError.runtime("test test requires a string") } + return self?.tests.keys.contains(name.value) ?? false + }, + "none": { args in + return args[0] is NullValue + }, + "boolean": { args in + return args[0] is BooleanValue }, "false": { args in if let arg = args[0] as? BooleanValue { @@ -62,24 +88,22 @@ class Environment { } return false }, - "number": { args in - args[0] is NumericValue - }, "integer": { args in if let arg = args[0] as? NumericValue { return arg.value is Int } - return false }, - "iterable": { args in - args[0] is ArrayValue || args[0] is StringValue + "float": { args in + if let numericValue = args[0] as? NumericValue { + return numericValue.value is Double + } + return false }, "lower": { args in if let arg = args[0] as? StringValue { return arg.value == arg.value.lowercased() } - return false }, "upper": { args in @@ -88,17 +112,1407 @@ class Environment { } return false }, - "none": { args in - args[0] is NullValue + "string": { args in + return args[0] is StringValue }, - "defined": { args in - !(args[0] is UndefinedValue) + "mapping": { args in + return args[0] is ObjectValue }, - "undefined": { args in - args[0] is UndefinedValue + "number": { args in + return args[0] is NumericValue + }, + "sequence": { args in + let value = args[0] + if value is ArrayValue || value is StringValue { + return true + } + return false + }, + "iterable": { args in + return args[0] is ArrayValue || args[0] is StringValue || args[0] is ObjectValue + }, + "callable": { args in + return args[0] is FunctionValue + }, + // TODO: Implement "sameas" + // TODO: Implement "escaped" + "in": { args in + guard let seq = args[1] as? ArrayValue else { + throw JinjaError.runtime("in test requires a sequence") + } + return seq.value.contains { item in + self.doEqualTo([args[0], item]) + } + }, + "==": { args in self.doEqualTo(args) }, + "eq": { args in self.doEqualTo(args) }, + "equalto": { args in self.doEqualTo(args) }, + "!=": { args in + guard args.count == 2 else { + throw JinjaError.runtime("!= test requires two arguments") + } + return !self.doEqualTo(args) + }, + "ne": { args in + guard args.count == 2 else { + throw JinjaError.runtime("ne test requires two arguments") + } + return !self.doEqualTo(args) + }, + ">": { args in + guard args.count == 2 else { + throw JinjaError.runtime("> test requires two arguments") + } + return try self.doGreaterThan(args) + }, + "gt": { args in + guard args.count == 2 else { + throw JinjaError.runtime("gt test requires two arguments") + } + return try self.doGreaterThan(args) + }, + "greaterthan": { args in + guard args.count == 2 else { + throw JinjaError.runtime("greaterthan test requires two arguments") + } + return try self.doGreaterThan(args) + }, + ">=": { args in + guard args.count == 2 else { + throw JinjaError.runtime(">= test requires two arguments") + } + return try self.doGreaterThanOrEqual(args) + }, + "ge": { args in + guard args.count == 2 else { + throw JinjaError.runtime("ge test requires two arguments") + } + return try self.doGreaterThanOrEqual(args) + }, + "<": { args in + guard args.count == 2 else { + throw JinjaError.runtime("< test requires two arguments") + } + return try self.doLessThan(args) + }, + "lt": { args in + guard args.count == 2 else { + throw JinjaError.runtime("lt test requires two arguments") + } + return try self.doLessThan(args) + }, + "lessthan": { args in + guard args.count == 2 else { + throw JinjaError.runtime("lessthan test requires two arguments") + } + return try self.doLessThan(args) + }, + "<=": { args in + guard args.count == 2 else { + throw JinjaError.runtime("<= test requires two arguments") + } + return try self.doLessThanOrEqual(args) + }, + "le": { args in + guard args.count == 2 else { + throw JinjaError.runtime("le test requires two arguments") + } + return try self.doLessThanOrEqual(args) + }, + ] + + lazy var filters: [String: ([any RuntimeValue], Environment) throws -> any RuntimeValue] = [ + "abs": { args, env in + guard args.count == 1 else { + throw JinjaError.runtime("abs filter requires exactly one argument, but \(args.count) were provided") + } + guard let numericValue = args[0] as? NumericValue else { + throw JinjaError.runtime("abs filter requires a number") + } + if let intValue = numericValue.value as? Int { + let absValue = abs(intValue) + return NumericValue(value: absValue) + } else if let doubleValue = numericValue.value as? Double { + let absValue = abs(doubleValue) + return NumericValue(value: absValue) + } else { + throw JinjaError.runtime("Unsupported numeric type for abs filter") + } + }, + "attr": { args, env in + guard args.count >= 2 else { + throw JinjaError.runtime("attr filter requires an object and attribute name") + } + let obj = args[0] + // Convert name to string (similar to str(name) in Python) + let name: String + if let stringValue = args[1] as? StringValue { + name = stringValue.value + } else { + // Try to convert the name to string + do { + name = try stringify(args[1]) + } catch { + return UndefinedValue() + } + } + // Handle different object types + if let objectValue = obj as? ObjectValue { + // Return the raw value if it exists + if let value = objectValue.value[name] { + return value + } + } + // If attribute is not found, return undefined + return UndefinedValue() + }, + "batch": { args, env in + guard let arrayValue = args[0] as? ArrayValue, + let linecount = args[1] as? NumericValue, + let count = linecount.value as? Int + else { + throw JinjaError.runtime("batch filter requires an array and line count") + } + let fillWith = args.count > 2 ? args[2] : nil + var result: [[any RuntimeValue]] = [] + var temp: [any RuntimeValue] = [] + for item in arrayValue.value { + if temp.count == count { + result.append(temp) + temp = [] + } + temp.append(item) + } + if !temp.isEmpty { + if let fill = fillWith, temp.count < count { + temp += Array(repeating: fill, count: count - temp.count) + } + result.append(temp) + } + return ArrayValue(value: result.map { ArrayValue(value: $0) }) + }, + "capitalize": { args, env in + guard let stringValue = args[0] as? StringValue else { + throw JinjaError.runtime("capitalize filter requires a string") + } + let str = stringValue.value + guard let firstChar = str.first else { + return stringValue // Empty string, return as is + } + return StringValue(value: String(firstChar).uppercased() + str.dropFirst().lowercased()) + }, + "center": { args, env in + guard let stringValue = args[0] as? StringValue else { + throw JinjaError.runtime("center filter requires a string") + } + let width = (args.count > 1 && args[1] is NumericValue) ? (args[1] as! NumericValue).value as! Int : 80 + let str = stringValue.value + + // If string is longer than width, return original string + if str.count >= width { + return stringValue + } + + // Calculate total padding needed + let padding = width - str.count + + // Calculate left and right padding + // When padding is odd, the extra space goes to the right + let leftPadding = padding / 2 + let rightPadding = padding - leftPadding // This ensures extra padding goes to the right + + // Create padded string + return StringValue( + value: String(repeating: " ", count: leftPadding) + str + String(repeating: " ", count: rightPadding) + ) + }, + "count": { args, env in + let value = args[0] + if let arrayValue = value as? ArrayValue { + return NumericValue(value: arrayValue.value.count) + } else if let stringValue = value as? StringValue { + return NumericValue(value: stringValue.value.count) + } else if let objectValue = value as? ObjectValue { + return NumericValue(value: objectValue.value.count) + } + throw JinjaError.runtime("Cannot count value of type \(type(of: value))") + }, + "d": { args, env in try self.doDefault(args, env) }, + "default": { args, env in try self.doDefault(args, env) }, + "dictsort": { args, env in + guard let dict = args[0] as? ObjectValue else { + throw JinjaError.runtime("dictsort filter requires a dictionary") + } + let caseSensitive = args.count > 1 ? (args[1] as? BooleanValue)?.value ?? false : false + let by = args.count > 2 ? (args[2] as? StringValue)?.value ?? "key" : "key" + let reverse = args.count > 3 ? (args[3] as? BooleanValue)?.value ?? false : false + let sortedDict = try dict.storage.sorted { (item1, item2) in + let a: Any, b: Any + if by == "key" { + a = item1.key + b = item2.key + } else if by == "value" { + a = item1.value + b = item2.value + } else { + throw JinjaError.runtime("Invalid 'by' argument for dictsort filter") + } + let result: Bool + if let aString = a as? String, let bString = b as? String { + result = caseSensitive ? aString < bString : aString.lowercased() < bString.lowercased() + } else if let aNumeric = a as? NumericValue, let bNumeric = b as? NumericValue { + if let aInt = aNumeric.value as? Int, let bInt = bNumeric.value as? Int { + result = aInt < bInt + } else if let aDouble = aNumeric.value as? Double, let bDouble = bNumeric.value as? Double { + result = aDouble < bDouble + } else { + throw JinjaError.runtime("Cannot compare values in dictsort filter") + } + } else { + throw JinjaError.runtime("Cannot compare values in dictsort filter") + } + return reverse ? !result : result + } + return ArrayValue( + value: sortedDict.map { (key, value) in + return ArrayValue(value: [StringValue(value: key), value]) + } + ) + }, + "e": { args, env in try self.doEscape(args, env) }, + "escape": { args, env in try self.doEscape(args, env) }, + "filesizeformat": { args, env in + guard let value = args[0] as? NumericValue else { + throw JinjaError.runtime("filesizeformat filter requires a numeric value") + } + + let size: Double + if let intValue = value.value as? Int { + size = Double(intValue) + } else if let doubleValue = value.value as? Double { + size = doubleValue + } else { + throw JinjaError.runtime("filesizeformat filter requires a numeric value") + } + + let binary = args.count > 1 ? (args[1] as? BooleanValue)?.value ?? false : false + let units = + binary + ? [" Bytes", " KiB", " MiB", " GiB", " TiB", " PiB", " EiB", " ZiB", " YiB"] + : [" Bytes", " kB", " MB", " GB", " TB", " PB", " EB", " ZB", " YB"] + let base: Double = binary ? 1024.0 : 1000.0 + + if size < base { + return StringValue(value: "\(Int(size)) Bytes") + } + + let exp = Int(log(size) / log(base)) + let unit = units[min(exp, units.count - 1)] + let num = size / pow(base, Double(exp)) + return StringValue(value: String(format: "%.1f%@", num, unit)) + }, + "first": { args, env in + guard let arrayValue = args[0] as? ArrayValue else { + throw JinjaError.runtime("first filter requires an array") + } + return arrayValue.value.first ?? UndefinedValue() + }, + "float": { args, env in + guard let value = args[0] as? NumericValue else { + return NumericValue(value: 0.0) + } + if let doubleValue = value.value as? Double { + return NumericValue(value: doubleValue) + } else if let intValue = value.value as? Int { + return NumericValue(value: Double(intValue)) + } else { + return NumericValue(value: 0.0) + } + }, + "forceescape": { args, env in + guard let stringValue = args[0] as? StringValue else { + throw JinjaError.runtime("forceescape filter requires a string") + } + return StringValue( + value: stringValue.value.replacingOccurrences(of: "&", with: "&") + .replacingOccurrences(of: "<", with: "<") + .replacingOccurrences(of: ">", with: ">") + .replacingOccurrences(of: "\"", with: """) + .replacingOccurrences(of: "'", with: "'") + ) + }, + "format": { args, env in + guard let formatString = args[0] as? StringValue else { + throw JinjaError.runtime("format filter requires a format string") + } + // Get the values after the format string + let formatArgs = Array(args.dropFirst()) + // Convert the values to strings + let formatValues = formatArgs.map { arg -> String in + if let stringValue = arg as? StringValue { + return stringValue.value + } else if let numericValue = arg as? NumericValue { + if let intValue = numericValue.value as? Int { + return String(intValue) + } else if let doubleValue = numericValue.value as? Double { + return String(doubleValue) + } + } + return String(describing: arg) + } + // Replace %s with values one by one + var result = formatString.value + for value in formatValues { + if let range = result.range(of: "%s") { + result.replaceSubrange(range, with: value) + } else if let range = result.range(of: "%d") { + result.replaceSubrange(range, with: value) + } + } + return StringValue(value: result) + }, + "groupby": { args, env in + guard let arrayValue = args[0] as? ArrayValue else { + throw JinjaError.runtime("groupby filter requires an array") + } + guard let attribute = args[1] as? StringValue else { + throw JinjaError.runtime("groupby filter requires an attribute name") + } + let defaultValue = args.count > 2 ? args[2] : nil + let caseSensitive = args.count > 3 ? (args[3] as? BooleanValue)?.value ?? false : false + + // Helper function to get nested attribute value + func getAttributeValue(_ obj: ObjectValue, _ path: String) -> any RuntimeValue { + let components = path.split(separator: ".") + var current: any RuntimeValue = obj + + for component in components { + if let currentObj = current as? ObjectValue, + let value = currentObj.value[String(component)] + { + current = value + } else { + return defaultValue ?? UndefinedValue() + } + } + return current + } + + // Sort the array first + let sorted = arrayValue.value.sorted { (a, b) in + guard let aObj = a as? ObjectValue, + let bObj = b as? ObjectValue + else { + return false + } + + let aValue = getAttributeValue(aObj, attribute.value) + let bValue = getAttributeValue(bObj, attribute.value) + + if let aStr = aValue as? StringValue, + let bStr = bValue as? StringValue + { + let aCompare = caseSensitive ? aStr.value : aStr.value.lowercased() + let bCompare = caseSensitive ? bStr.value : bStr.value.lowercased() + return aCompare < bCompare + } + // Add other comparison types as needed + return false + } + + // Group the sorted array + var groups: [(grouper: any RuntimeValue, list: [any RuntimeValue])] = [] + var currentGroup: [any RuntimeValue] = [] + var currentKey: (any RuntimeValue)? = nil // Changed to var and explicitly initialized as nil + + for item in sorted { + guard let obj = item as? ObjectValue else { continue } + let value = getAttributeValue(obj, attribute.value) + let key = + caseSensitive + ? value : (value as? StringValue).map { StringValue(value: $0.value.lowercased()) } ?? value + + if let existingKey = currentKey { // Changed to different name for binding + if self.doEqualTo([key, existingKey]) { + currentGroup.append(item) + } else { + if !currentGroup.isEmpty { + // Use the first item's actual value as the grouper + if let firstObj = currentGroup[0] as? ObjectValue { + let grouper = getAttributeValue(firstObj, attribute.value) + groups.append((grouper: grouper, list: currentGroup)) + } + } + currentGroup = [item] + currentKey = key // Now works because currentKey is var + } + } else { + currentGroup = [item] + currentKey = key + } + } + + // Add the last group + if !currentGroup.isEmpty { + if let firstObj = currentGroup[0] as? ObjectValue { + let grouper = getAttributeValue(firstObj, attribute.value) + groups.append((grouper: grouper, list: currentGroup)) + } + } + + // Convert groups to array of objects with 'grouper' and 'list' keys + return ArrayValue( + value: groups.map { group in + ObjectValue(value: [ + "grouper": group.grouper, + "list": ArrayValue(value: group.list), + ]) + } + ) + }, + "indent": { args, env in + guard let stringValue = args[0] as? StringValue else { + throw JinjaError.runtime("indent filter requires a string") + } + // Determine indentation width + var indent: String + if args.count > 1 { + if let width = args[1] as? NumericValue, let intWidth = width.value as? Int { + indent = String(repeating: " ", count: intWidth) + } else if let stringWidth = args[1] as? StringValue { + indent = stringWidth.value + } else { + indent = String(repeating: " ", count: 4) // Default + } + } else { + indent = String(repeating: " ", count: 4) // Default + } + let first = args.count > 2 ? (args[2] as? BooleanValue)?.value ?? false : false + let blank = args.count > 3 ? (args[3] as? BooleanValue)?.value ?? false : false + // Add a newline to the end of the string (Python quirk) + let modifiedStringValue = stringValue.value + "\n" + // Split into lines + var lines = modifiedStringValue.components(separatedBy: "\n") + // Remove the last line (which is always empty due to the added newline) + lines.removeLast() + if lines.isEmpty { + return StringValue(value: "") + } + var result: String + // Handle first line + if first { + result = indent + lines[0] + } else { + result = lines[0] + } + // Process remaining lines + if lines.count > 1 { + let remainingLines = lines.dropFirst().map { line -> String in + if line.isEmpty { + return blank ? indent + line : line + } else { + return indent + line + } + } + result += "\n" + remainingLines.joined(separator: "\n") + } + return StringValue(value: result) + }, + "int": { args, env in + if let numericValue = args[0] as? NumericValue { + if let intValue = numericValue.value as? Int { + return NumericValue(value: intValue) + } else if let doubleValue = numericValue.value as? Double { + return NumericValue(value: Int(doubleValue)) + } + } else if let stringValue = args[0] as? StringValue { + if let intValue = Int(stringValue.value) { + return NumericValue(value: intValue) + } else if let doubleValue = Double(stringValue.value) { + return NumericValue(value: Int(doubleValue)) + } + } + // Return 0 for any other case (including invalid strings) + return NumericValue(value: 0) + }, + "items": { args, env in + guard let value = args.first else { + throw JinjaError.runtime("items filter requires an argument") + } + // Handle undefined values by returning empty array + if value is UndefinedValue { + return ArrayValue(value: []) + } + // Handle objects (mappings) + if let objectValue = value as? ObjectValue { + return ArrayValue( + value: objectValue.storage.map { (key, value) in + ArrayValue(value: [StringValue(value: key), value]) + } + ) + } + + throw JinjaError.runtime("Can only get item pairs from a mapping.") + }, + "join": { args, env in + guard let arrayValue = args[0] as? ArrayValue else { + throw JinjaError.runtime("join filter requires an array") + } + let separator = (args.count > 1 && args[1] is StringValue) ? (args[1] as! StringValue).value : "" + // Convert all values to strings before joining + let stringValues = try arrayValue.value.map { value -> String in + if let stringValue = value as? StringValue { + return stringValue.value + } else { + // Convert other types to string using stringify function + return try stringify(value) + } + } + return StringValue(value: stringValues.joined(separator: separator)) + }, + "last": { args, env in + guard let arrayValue = args[0] as? ArrayValue else { + throw JinjaError.runtime("last filter requires an array") + } + return arrayValue.value.last ?? UndefinedValue() + }, + "length": { args, env in + guard let arg = args.first else { + throw JinjaError.runtime("length filter expects one argument") + } + + if let arrayValue = arg as? ArrayValue { + return NumericValue(value: arrayValue.value.count) + } else if let stringValue = arg as? StringValue { + return NumericValue(value: stringValue.value.count) + } else if let objectValue = arg as? ObjectValue { + return NumericValue(value: objectValue.value.count) + } else { + throw JinjaError.runtime("Cannot get length of type: \(type(of: arg))") + } + }, + "list": { args, env in + guard let arrayValue = args[0] as? ArrayValue else { + throw JinjaError.runtime("list filter requires an array") + } + return arrayValue + }, + "lower": { args, env in + guard let stringValue = args[0] as? StringValue else { + throw JinjaError.runtime("lower filter requires a string") + } + return StringValue(value: stringValue.value.lowercased()) + }, + "map": { args, env in + guard let arrayValue = args[0] as? ArrayValue else { + // Handle None/empty case + if args[0] is NullValue { + return ArrayValue(value: []) + } + throw JinjaError.runtime("map filter requires an array") + } + // Handle attribute mapping + if args.count >= 2, let kwargs = args.last as? ObjectValue, + let attribute = kwargs.value["attribute"] as? StringValue + { + let defaultValue = kwargs.value["default"] // Get default value if provided + return ArrayValue( + value: arrayValue.value.map { item in + if let objectValue = item as? ObjectValue { + if let value = objectValue.value[attribute.value] { + if value is UndefinedValue { + // If value is explicitly undefined, return "None" + return StringValue(value: "None") + } + if value is NullValue { + // If value is explicitly null, return default if provided + return defaultValue ?? StringValue(value: "None") + } + return value + } else { + // If attribute doesn't exist, use default + return defaultValue ?? StringValue(value: "None") + } + } + return defaultValue ?? StringValue(value: "None") + } + ) + } + // Handle function mapping by name + if let functionName = args[1] as? StringValue { + guard let filter = env.filters[functionName.value] else { + throw JinjaError.runtime("Unknown function: \(functionName.value)") + } + return ArrayValue( + value: try arrayValue.value.map { item in + try filter([item], env) + } + ) + } + throw JinjaError.runtime("map filter requires either an attribute name or a function name") + }, + "min": { args, env in + guard let arrayValue = args[0] as? ArrayValue else { + throw JinjaError.runtime("min filter requires an array") + } + if arrayValue.value.isEmpty { + return UndefinedValue() + } + if let numericValues = arrayValue.value as? [NumericValue] { + let ints = numericValues.compactMap { $0.value as? Int } + let doubles = numericValues.compactMap { $0.value as? Double } + if !ints.isEmpty, doubles.isEmpty { + if let min = ints.min() { + return NumericValue(value: min) + } else { + throw JinjaError.runtime("min value of array in min filter could not be determined") + } + } else if !doubles.isEmpty, ints.isEmpty { + if let min = doubles.min() { + return NumericValue(value: min) + } else { + throw JinjaError.runtime("min value of array in min filter could not be determined") + } + } else { + throw JinjaError.runtime("min filter requires all array elements to be of type Int or Double") + } + } else if let stringValues = arrayValue.value as? [StringValue] { + return StringValue(value: stringValues.map { $0.value }.min() ?? "") + } else { + throw JinjaError.runtime("min filter requires an array of numbers or strings") + } + }, + "max": { args, env in + guard let arrayValue = args[0] as? ArrayValue else { + throw JinjaError.runtime("max filter requires an array") + } + if arrayValue.value.isEmpty { + return UndefinedValue() + } + if let numericValues = arrayValue.value as? [NumericValue] { + let ints = numericValues.compactMap { $0.value as? Int } + let doubles = numericValues.compactMap { $0.value as? Double } + if !ints.isEmpty, doubles.isEmpty { + if let max = ints.max() { + return NumericValue(value: max) + } else { + throw JinjaError.runtime("max value of array in max filter cannot be determined") + } + } else if !doubles.isEmpty, ints.isEmpty { + if let max = doubles.max() { + return NumericValue(value: max) + } else { + throw JinjaError.runtime("max value of array in max filter cannot be determined") + } + } else { + throw JinjaError.runtime("max filter requires all array elements to be of type Int or Double") + } + } else if let stringValues = arrayValue.value as? [StringValue] { + return StringValue(value: stringValues.map { $0.value }.max() ?? "") + } else { + throw JinjaError.runtime("max filter requires an array of numbers or strings") + } + }, + "pprint": { args, env in + guard let value = args.first else { + throw JinjaError.runtime("pprint filter expects one argument") + } + return StringValue(value: String(describing: value)) + }, + "random": { args, env in + guard let arrayValue = args[0] as? ArrayValue else { + throw JinjaError.runtime("random filter requires an array") + } + if let randomIndex = arrayValue.value.indices.randomElement() { + return arrayValue.value[randomIndex] + } else { + return UndefinedValue() + } + }, + "reject": { args, env in + guard let arrayValue = args[0] as? ArrayValue else { + throw JinjaError.runtime("reject filter requires an array") + } + guard let testName = args[1] as? StringValue else { + throw JinjaError.runtime("reject filter requires a test name") + } + guard let test = env.tests[testName.value] else { + throw JinjaError.runtime("Unknown test '\(testName.value)'") + } + var result: [any RuntimeValue] = [] + for item in arrayValue.value { + // Correctly pass arguments to the test function + if try !test(item) { // Negate the result for 'reject' + result.append(item) + } + } + return ArrayValue(value: result) + }, + "rejectattr": { args, env in + guard let arrayValue = args[0] as? ArrayValue else { + throw JinjaError.runtime("rejectattr filter requires an array") + } + guard let attribute = args[1] as? StringValue else { + throw JinjaError.runtime("rejectattr filter requires an attribute name") + } + var result: [any RuntimeValue] = [] + for item in arrayValue.value { + guard let objectValue = item as? ObjectValue, + let attrValue = objectValue.value[attribute.value] + else { + continue + } + if args.count == 2 { + if !attrValue.bool() { + result.append(item) + } + } else { + let testName = (args[2] as? StringValue)?.value ?? "defined" + guard let test = env.tests[testName] else { + throw JinjaError.runtime("Unknown test '\(testName)'") + } + // Correctly pass arguments to the test function + if try !test(attrValue) { // Note the negation (!) for rejectattr + result.append(item) + } + } + } + return ArrayValue(value: result) + }, + "replace": { args, env in + guard let stringValue = args[0] as? StringValue else { + throw JinjaError.runtime("replace filter requires a string") + } + guard let oldValue = args[1] as? StringValue else { + throw JinjaError.runtime("replace filter requires an old value string") + } + guard let newValue = args[2] as? StringValue else { + throw JinjaError.runtime("replace filter requires a new value string") + } + let count = (args.count > 3 && args[3] is NumericValue) ? (args[3] as! NumericValue).value as! Int : Int.max + return StringValue( + value: stringValue.value.replacingOccurrences( + of: oldValue.value, + with: newValue.value, + options: [], + range: nil + ) + ) + }, + "reverse": { args, env in + guard let arrayValue = args[0] as? ArrayValue else { + throw JinjaError.runtime("reverse filter requires an array") + } + return ArrayValue(value: arrayValue.value.reversed()) + }, + "round": { args, env in + guard let value = args[0] as? NumericValue, let number = value.value as? Double else { + throw JinjaError.runtime("round filter requires a number") + } + let precision = (args.count > 1 && args[1] is NumericValue) ? (args[1] as! NumericValue).value as! Int : 0 + let method = (args.count > 2 && args[2] is StringValue) ? (args[2] as! StringValue).value : "common" + let factor = pow(10, Double(precision)) + let roundedNumber: Double + if method == "common" { + roundedNumber = round(number * factor) / factor + } else if method == "ceil" { + roundedNumber = ceil(number * factor) / factor + } else if method == "floor" { + roundedNumber = floor(number * factor) / factor + } else { + throw JinjaError.runtime("Invalid method for round filter") + } + return NumericValue(value: roundedNumber) + }, + "safe": { args, env in + guard let stringValue = args[0] as? StringValue else { + throw JinjaError.runtime("safe filter requires a string") + } + return stringValue // In this minimal example, we don't handle marking strings as safe + }, + "select": { args, env in + guard let arrayValue = args[0] as? ArrayValue else { + throw JinjaError.runtime("select filter requires an array") + } + guard let testName = args[1] as? StringValue else { + throw JinjaError.runtime("select filter requires a test name") + } + guard let test = env.tests[testName.value] else { + throw JinjaError.runtime("Unknown test '\(testName.value)'") + } + var result: [any RuntimeValue] = [] + for item in arrayValue.value { + if try test(item) { + result.append(item) + } + } + return ArrayValue(value: result) + }, + "selectattr": { args, env in + guard let array = args[0] as? ArrayValue else { + throw JinjaError.runtime("selectattr filter requires an array") + } + guard let attribute = args[1] as? StringValue else { + throw JinjaError.runtime("selectattr filter requires an attribute name") + } + guard args.count > 2 else { + throw JinjaError.runtime("selectattr filter requires a test") + } + var result: [any RuntimeValue] = [] + for item in array.value { + if let obj = item as? ObjectValue, + let attrValue = obj.value[attribute.value] + { + if args[2] is StringValue && args[2].bool() { + // Simple boolean test + if attrValue.bool() { + result.append(item) + } + } else if args.count > 3 { + // Test with comparison value + if let testName = (args[2] as? StringValue)?.value { + let testValue = args[3] + if testName == "equalto" { + // Handle equality test + if let strAttr = attrValue as? StringValue, + let strTest = testValue as? StringValue + { + if strAttr.value == strTest.value { + result.append(item) + } + } + } + } + } + } + } + return ArrayValue(value: result) + }, + "slice": { args, env in + guard let arrayValue = args[0] as? ArrayValue else { + throw JinjaError.runtime("slice filter requires an array") + } + guard let slicesValue = args[1] as? NumericValue, + let slices = slicesValue.value as? Int, + slices > 0 + else { + throw JinjaError.runtime("slice filter requires a positive number of slices") + } + + let fillWith = args.count > 2 ? args[2] : nil + let seq = arrayValue.value + let length = seq.count + let itemsPerSlice = length / slices + let slicesWithExtra = length % slices + var offset = 0 + + var result: [[any RuntimeValue]] = [] + + for sliceNumber in 0 ..< slices { + let start = offset + sliceNumber * itemsPerSlice + + if sliceNumber < slicesWithExtra { + offset += 1 + } + + let end = offset + (sliceNumber + 1) * itemsPerSlice + var tmp = Array(seq[start ..< end]) + + if let fillWith = fillWith, sliceNumber >= slicesWithExtra { + tmp.append(fillWith) + } + + result.append(tmp) + } + + return ArrayValue(value: result.map { ArrayValue(value: $0) }) + }, + "sort": { args, env in + guard let arrayValue = args[0] as? ArrayValue else { + throw JinjaError.runtime("sort filter requires an array") + } + + let reverse = args.count > 1 ? (args[1] as? BooleanValue)?.value ?? false : false + let caseSensitive = args.count > 2 ? (args[2] as? BooleanValue)?.value ?? false : false + let attributeStr = args.count > 3 ? (args[3] as? StringValue)?.value : nil + + // Helper function to get value from dot notation path + func getValueFromPath(_ obj: any RuntimeValue, _ path: String) throws -> any RuntimeValue { + let components = path.split(separator: ".") + var current: any RuntimeValue = obj + + for component in components { + if let currentObj = current as? ObjectValue, + let value = currentObj.value[String(component)] + { + current = value + } else if let currentArray = current as? ArrayValue, + let index = Int(component), + index >= 0 && index < currentArray.value.count + { + current = currentArray.value[index] + } else { + throw JinjaError.runtime("Cannot access '\(component)' in path '\(path)'") + } + } + return current + } + + // Helper function to compare RuntimeValues + func compare(_ a: any RuntimeValue, _ b: any RuntimeValue) throws -> Bool { + if let aStr = a as? StringValue, let bStr = b as? StringValue { + if caseSensitive { + return aStr.value < bStr.value + } else { + return aStr.value.lowercased() < bStr.value.lowercased() + } + } else if let aNum = a as? NumericValue, let bNum = b as? NumericValue { + if let aInt = aNum.value as? Int, let bInt = bNum.value as? Int { + return aInt < bInt + } else if let aDouble = aNum.value as? Double, let bDouble = bNum.value as? Double { + return aDouble < bDouble + } else if let aInt = aNum.value as? Int, let bDouble = bNum.value as? Double { + return Double(aInt) < bDouble + } else if let aDouble = aNum.value as? Double, let bInt = bNum.value as? Int { + return aDouble < Double(bInt) + } + } + throw JinjaError.runtime("Cannot compare values of different types") + } + + // Sort the array + let sortedArray = try arrayValue.value.sorted { (a, b) -> Bool in + if let attributeStr = attributeStr { + // Handle multiple attributes (comma-separated) + let attributes = attributeStr.split(separator: ",").map(String.init) + + for attribute in attributes { + let aValue = try getValueFromPath(a, attribute.trimmingCharacters(in: .whitespaces)) + let bValue = try getValueFromPath(b, attribute.trimmingCharacters(in: .whitespaces)) + + // If values are equal, continue to next attribute + if try compare(aValue, bValue) == compare(bValue, aValue) { + continue + } + + return reverse ? try !compare(aValue, bValue) : try compare(aValue, bValue) + } + // All attributes were equal + return false + } else { + return reverse ? try !compare(a, b) : try compare(a, b) + } + } + + return ArrayValue(value: sortedArray) + }, + "string": { args, env in + guard let arg = args.first else { + throw JinjaError.runtime("string filter expects one argument") + } + // In Jinja2 in Python, the `string` filter calls Python's `str` function on dicts, which which uses single quotes for strings. Here we're using double quotes in `tojson`, which is probably better for LLMs anyway, but this will result in differences with output from Jinja2. + return try StringValue(value: stringify(arg, whitespaceControl: true)) + }, + "striptags": { args, env in + guard let stringValue = args[0] as? StringValue else { + throw JinjaError.runtime("striptags filter requires a string") + } + // A very basic implementation to remove HTML tags + let tagPattern = #"<[^>]+>"# + let noTagsString = stringValue.value.replacingOccurrences( + of: tagPattern, + with: "", + options: .regularExpression + ) + return StringValue(value: noTagsString) + }, + "sum": { args, env in + guard let arrayValue = args[0] as? ArrayValue else { + throw JinjaError.runtime("sum filter requires an array") + } + + // Get attribute and start value from arguments + let attribute = args.count > 1 ? args[1] : nil + let start: Double = { + if args.count > 2, let numericValue = args[2] as? NumericValue { + if let intValue = numericValue.value as? Int { + return Double(intValue) + } else if let doubleValue = numericValue.value as? Double { + return doubleValue + } + } + return 0.0 + }() + + // Helper function to get value based on attribute + func getValue(_ item: any RuntimeValue) throws -> Double { + if let attribute = attribute { + // Handle string attribute (object property) + if let strAttr = attribute as? StringValue, + let objectValue = item as? ObjectValue, + let attrValue = objectValue.value[strAttr.value] + { + if let numericValue = attrValue as? NumericValue { + if let intValue = numericValue.value as? Int { + return Double(intValue) + } else if let doubleValue = numericValue.value as? Double { + return doubleValue + } + } + throw JinjaError.runtime("Attribute '\(strAttr.value)' is not numeric") + } + // Handle integer attribute (array/string index) + else if let numAttr = attribute as? NumericValue, + let index = numAttr.value as? Int + { + if let arrayValue = item as? ArrayValue { + guard index >= 0 && index < arrayValue.value.count else { + throw JinjaError.runtime("Index \(index) out of range") + } + if let numericValue = arrayValue.value[index] as? NumericValue { + if let intValue = numericValue.value as? Int { + return Double(intValue) + } else if let doubleValue = numericValue.value as? Double { + return doubleValue + } + } + throw JinjaError.runtime("Value at index \(index) is not numeric") + } + } + throw JinjaError.runtime("Cannot get attribute '\(try stringify(attribute))' from item") + } else { + // No attribute - use item directly + if let numericValue = item as? NumericValue { + if let intValue = numericValue.value as? Int { + return Double(intValue) + } else if let doubleValue = numericValue.value as? Double { + return doubleValue + } + } + throw JinjaError.runtime("Item is not numeric") + } + } + + // Sum all values + var result = start + for item in arrayValue.value { + do { + result += try getValue(item) + } catch { + throw JinjaError.runtime("Could not sum items: \(error.localizedDescription)") + } + } + + // Return result as NumericValue + // If the result has no decimal part, return as Int + if result.truncatingRemainder(dividingBy: 1) == 0 { + return NumericValue(value: Int(result)) + } + return NumericValue(value: result) + }, + "title": { args, env in + guard let stringValue = args[0] as? StringValue else { + throw JinjaError.runtime("title filter requires a string") + } + + // Split the string by spaces, hyphens, and opening brackets/braces/parentheses + let pattern = "([-\\s(\\{\\[<]+)" + let regex = try! NSRegularExpression(pattern: pattern, options: []) + let str = stringValue.value + let range = NSRange(str.startIndex ..< str.endIndex, in: str) + + // Split the string and keep the delimiters + let matches = regex.matches(in: str, options: [], range: range) + var parts: [String] = [] + var currentIndex = str.startIndex + + // Add the first part if it exists + if let firstMatch = matches.first, + let firstMatchRange = Range(firstMatch.range, in: str) + { + if currentIndex < firstMatchRange.lowerBound { + parts.append(String(str[currentIndex ..< firstMatchRange.lowerBound])) + } + parts.append(String(str[firstMatchRange])) + currentIndex = firstMatchRange.upperBound + } + + // Add remaining parts and delimiters + for i in 1 ..< matches.count { + if let matchRange = Range(matches[i].range, in: str) { + if currentIndex < matchRange.lowerBound { + parts.append(String(str[currentIndex ..< matchRange.lowerBound])) + } + parts.append(String(str[matchRange])) + currentIndex = matchRange.upperBound + } + } + + // Add the last part if it exists + if currentIndex < str.endIndex { + parts.append(String(str[currentIndex ..< str.endIndex])) + } + + // Process each part and join them + let result = parts.filter { !$0.isEmpty }.map { part -> String in + if part.matches(of: try! Regex(pattern)).isEmpty { + // This is a word part, not a delimiter + if let first = part.first { + return String(first).uppercased() + part.dropFirst().lowercased() + } + return part + } + // This is a delimiter, keep it as is + return part + }.joined() + + return StringValue(value: result) + }, + "trim": { args, env in + guard let stringValue = args[0] as? StringValue else { + throw JinjaError.runtime("trim filter requires a string") + } + return StringValue(value: stringValue.value.trimmingCharacters(in: .whitespacesAndNewlines)) }, - "equalto": { _ in - throw JinjaError.syntaxNotSupported("equalto") + "truncate": { args, env in + guard let stringValue = args[0] as? StringValue else { + throw JinjaError.runtime("truncate filter requires a string") + } + let length = (args.count > 1 && args[1] is NumericValue) ? (args[1] as! NumericValue).value as! Int : 255 + let killwords = (args.count > 2 && args[2] is BooleanValue) ? (args[2] as! BooleanValue).value : false + let end = (args.count > 3 && args[3] is StringValue) ? (args[3] as! StringValue).value : "..." + if stringValue.value.count <= length { + return stringValue + } + if killwords { + return StringValue(value: String(stringValue.value.prefix(length - end.count)) + end) + } else { + let truncated = String(stringValue.value.prefix(length - end.count)) + if let lastSpace = truncated.lastIndex(of: " ") { + return StringValue(value: String(truncated[.. [any RuntimeValue] { + switch value { + case let arrayValue as ArrayValue: + return arrayValue.value + case let stringValue as StringValue: + // Always split string into characters as StringValues + return stringValue.value.map { StringValue(value: String($0)) } + case let objectValue as ObjectValue: + return objectValue.storage.map { key, value in + ArrayValue(value: [StringValue(value: key), value]) + } + default: + throw JinjaError.runtime("Value must be iterable (array, string, or object)") + } + } + // Get the input iterable + guard let input = args.first else { + throw JinjaError.runtime("unique filter requires an iterable") + } + let caseSensitive = args.count > 1 ? (args[1] as? BooleanValue)?.value ?? false : false + let attribute = args.count > 2 ? args[2] : nil + // Helper function to get value based on attribute + func getValue(_ item: any RuntimeValue) throws -> String { + if let attribute = attribute { + // Handle string attribute (object property) + if let strAttr = attribute as? StringValue, + let objectValue = item as? ObjectValue + { + // Support dot notation + let components = strAttr.value.split(separator: ".") + var current: any RuntimeValue = objectValue + + for component in components { + if let currentObj = current as? ObjectValue, + let value = currentObj.value[String(component)] + { + current = value + } else { + throw JinjaError.runtime("Cannot access '\(component)' in path '\(strAttr.value)'") + } + } + return try stringify(current) + } + // Handle integer attribute (array/string index) + else if let numAttr = attribute as? NumericValue, + let index = numAttr.value as? Int + { + if let stringValue = item as? StringValue { + let str = stringValue.value + guard index >= 0 && index < str.count else { + throw JinjaError.runtime("Index \(index) out of range") + } + let stringIndex = str.index(str.startIndex, offsetBy: index) + return String(str[stringIndex]) + } else if let arrayValue = item as? ArrayValue { + guard index >= 0 && index < arrayValue.value.count else { + throw JinjaError.runtime("Index \(index) out of range") + } + return try stringify(arrayValue.value[index]) + } + } + } + // No attribute - use item directly + return try stringify(item) + } + var seen: [String: Bool] = [:] + var result: [any RuntimeValue] = [] + // Process all items from the iterable + let items = try getIterableItems(input) + for item in items { + let key = try getValue(item) + let lookupKey = caseSensitive ? key : key.lowercased() + + if seen[lookupKey] == nil { + seen[lookupKey] = true + result.append(item) + } + } + return ArrayValue(value: result) + }, + "upper": { args, env in + guard let stringValue = args[0] as? StringValue else { + throw JinjaError.runtime("upper filter requires a string") + } + return StringValue(value: stringValue.value.uppercased()) + }, + "urlencode": { args, env in + guard let stringValue = args[0] as? StringValue else { + throw JinjaError.runtime("urlencode filter requires a string") + } + + let encodedString = stringValue.value.addingPercentEncoding(withAllowedCharacters: .urlQueryAllowed) ?? "" + return StringValue(value: encodedString) + }, + "urlize": { args, env in + guard let stringValue = args[0] as? StringValue else { + throw JinjaError.runtime("urlize filter requires a string") + } + let trimUrlLimit = + (args.count > 1 && args[1] is NumericValue) ? (args[1] as! NumericValue).value as? Int : nil + let nofollow = (args.count > 2 && args[2] is BooleanValue) ? (args[2] as! BooleanValue).value : false + let target = (args.count > 3 && args[3] is StringValue) ? (args[3] as! StringValue).value : nil + let urlPattern = + #"(https?:\/\/(?:www\.|(?!www))[a-zA-Z0-9][a-zA-Z0-9-]+[a-zA-Z0-9]\.[^\s]{2,}|www\.[a-zA-Z0-9][a-zA-Z0-9-]+[a-zA-Z0-9]\.[^\s]{2,}|https?:\/\/(?:www\.|(?!www))[a-zA-Z0-9]+\.[^\s]{2,}|www\.[a-zA-Z0-9]+\.[^\s]{2,})"# + var urlizedString = stringValue.value + if let regex = try? NSRegularExpression(pattern: urlPattern, options: []) { + let nsRange = NSRange( + stringValue.value.startIndex ..< stringValue.value.endIndex, + in: stringValue.value + ) + let matches = regex.matches(in: stringValue.value, options: [], range: nsRange) + + for match in matches.reversed() { + let urlRange = Range(match.range, in: stringValue.value)! + let url = String(stringValue.value[urlRange]) + var trimmedUrl = url + if let limit = trimUrlLimit, url.count > limit { + trimmedUrl = String(url.prefix(limit)) + "..." + } + var link = " 1 && args[1] is NumericValue) ? (args[1] as! NumericValue).value as! Int : 79 + let breakLongWords = (args.count > 2 && args[2] is BooleanValue) ? (args[2] as! BooleanValue).value : true + let wrapString = (args.count > 3 && args[3] is StringValue) ? (args[3] as! StringValue).value : "\n" + var result = "" + var currentLineLength = 0 + for word in stringValue.value.split(separator: " ", omittingEmptySubsequences: false) { + if currentLineLength + word.count > width { + if currentLineLength > 0 { + result += wrapString + currentLineLength = 0 + } + if word.count > width && breakLongWords { + var remainingWord = word[...] + while remainingWord.count > width { + result += remainingWord.prefix(width) + result += wrapString + remainingWord = remainingWord.dropFirst(width) + } + if !remainingWord.isEmpty { + result += remainingWord + currentLineLength = remainingWord.count + } + continue + } + } + if !result.isEmpty && currentLineLength == 0 { + result += word + currentLineLength = word.count + } else { + if !result.isEmpty { + result += " " + currentLineLength += 1 + } + result += word + currentLineLength += word.count + } + } + return StringValue(value: result) + }, + "xmlattr": { args, env in + guard let dict = args[0] as? ObjectValue else { + throw JinjaError.runtime("xmlattr filter requires a dictionary") + } + let autospace = args.count > 1 ? (args[1] as? BooleanValue)?.value ?? true : true + var result = "" + for (key, value) in dict.storage { + if !(value is UndefinedValue) && !(value is NullValue) { + if autospace { + result += " " + } + if let stringValue = value as? StringValue { + result += + "\(key)=\"\(stringValue.value.replacingOccurrences(of: "&", with: "&").replacingOccurrences(of: "\"", with: """))\"" + } else { + result += "\(key)=\"\(value)\"" + } + } + } + return StringValue(value: result) + }, + "tojson": { args, env in + guard let firstArg = args.first else { + throw JinjaError.runtime("tojson filter expects at least one argument") + } + var indent: Int? = nil + if args.count > 1, let kwargs = args.last as? ObjectValue, + let indentArg = kwargs.value["indent"] as? NumericValue, + let indentInt = indentArg.value as? Int + { + indent = indentInt + } + return try StringValue(value: toJSON(firstArg, indent: indent, whitespaceControl: false)) }, ] @@ -106,82 +1520,132 @@ class Environment { self.parent = parent } - func isFunction(_ value: Any, functionType: T.Type) -> Bool { - value is T - } + // func isFunction(_ value: Any, functionType: T.Type) -> Bool { + // return value is T + // } - func convertToRuntimeValues(input: Any) throws -> any RuntimeValue { + func convertToRuntimeValues(input: Any?) throws -> any RuntimeValue { + // Handle already converted RuntimeValue + if let runtimeValue = input as? any RuntimeValue { + return runtimeValue + } + // Handle nil values + if input == nil { + return NullValue() + } + if case Optional.none = input { + return NullValue() + } + // Helper function to handle any OrderedDictionary type + func convertOrderedDictionary(_ dict: OrderedDictionary) throws -> ObjectValue { + var object: [String: any RuntimeValue] = [:] + var keyOrder: [String] = [] + + for (key, value) in dict { + // Crucial: Convert Optional to T, using NullValue if nil + let convertedValue = (value as Any?) ?? NullValue() + object[key] = try self.convertToRuntimeValues(input: convertedValue) + keyOrder.append(key) + } + return ObjectValue(value: object, keyOrder: keyOrder) + } + // Handle other values switch input { case let value as Bool: return BooleanValue(value: value) - case let values as [any Numeric]: - var items: [any RuntimeValue] = [] - for value in values { - try items.append(self.convertToRuntimeValues(input: value)) - } - return ArrayValue(value: items) - case let value as any Numeric: + case let value as Int: + return NumericValue(value: value) + case let value as Double: + return NumericValue(value: value) + case let value as Float: return NumericValue(value: value) case let value as String: return StringValue(value: value) + case let data as Data: + guard let string = String(data: data, encoding: .utf8) else { + throw JinjaError.runtime("Failed to convert data to string") + } + return StringValue(value: string) case let fn as (String) throws -> Void: return FunctionValue { args, _ in - var arg = "" - switch args[0].value { - case let value as String: - arg = value - case let value as Bool: - arg = String(value) - default: - throw JinjaError.runtime("Unknown arg type:\(type(of: args[0].value))") + guard let stringArg = args[0] as? StringValue else { + throw JinjaError.runtime("Argument must be a StringValue") } - - try fn(arg) + try fn(stringArg.value) return NullValue() } case let fn as (Bool) throws -> Void: return FunctionValue { args, _ in - try fn(args[0].value as! Bool) + guard let boolArg = args[0] as? BooleanValue else { + throw JinjaError.runtime("Argument must be a BooleanValue") + } + try fn(boolArg.value) return NullValue() } case let fn as (Int, Int?, Int) -> [Int]: return FunctionValue { args, _ in - let result = fn(args[0].value as! Int, args[1].value as? Int, args[2].value as! Int) - return try self.convertToRuntimeValues(input: result) - } - case let values as [Any]: - var items: [any RuntimeValue] = [] - for value in values { - try items.append(self.convertToRuntimeValues(input: value)) + guard args.count > 0, let arg0 = args[0] as? NumericValue, let int0 = arg0.value as? Int else { + throw JinjaError.runtime("First argument must be an Int") + } + var int1: Int? = nil + if args.count > 1 { + if let numericValue = args[1] as? NumericValue, let tempInt1 = numericValue.value as? Int { + int1 = tempInt1 + } else if !(args[1] is NullValue) { // Accept NullValue for optional second argument + throw JinjaError.runtime("Second argument must be an Int or nil") + } + } + var int2: Int = 1 + if args.count > 2 { + if let numericValue = args[2] as? NumericValue, let tempInt2 = numericValue.value as? Int { + int2 = tempInt2 + } else { + throw JinjaError.runtime("Third argument must be an Int") + } + } + let result = fn(int0, int1, int2) + return ArrayValue(value: result.map { NumericValue(value: $0) }) } + case let values as [Any?]: + let items = try values.map { try self.convertToRuntimeValues(input: $0) } return ArrayValue(value: items) - case let dictionary as [String: String]: + case let orderedDict as OrderedDictionary: + return try convertOrderedDictionary(orderedDict) + case let orderedDict as OrderedDictionary>: + return try convertOrderedDictionary(orderedDict) + case let orderedDict as OrderedDictionary>: + return try convertOrderedDictionary(orderedDict) + case let orderedDict as OrderedDictionary: + return try convertOrderedDictionary(orderedDict) + case let orderedDict as OrderedDictionary: + return try convertOrderedDictionary(orderedDict) + case let dictionary as [String: Any?]: var object: [String: any RuntimeValue] = [:] - + var keyOrder: [String] = [] for (key, value) in dictionary { - object[key] = StringValue(value: value) + object[key] = try self.convertToRuntimeValues(input: value) + keyOrder.append(key) } - - return ObjectValue(value: object) - case is NullValue: - return NullValue() + return ObjectValue(value: object, keyOrder: keyOrder) default: - throw JinjaError.runtime("Cannot convert to runtime value: \(input) type:\(type(of: input))") + throw JinjaError.runtime( + "Cannot convert to runtime value: \(String(describing: input)) type:\(type(of: input))" + ) } } @discardableResult func set(name: String, value: Any) throws -> any RuntimeValue { - try self.declareVariable(name: name, value: self.convertToRuntimeValues(input: value)) + let runtimeValue = try self.convertToRuntimeValues(input: value) + return try self.declareVariable(name: name, value: runtimeValue) } - func declareVariable(name: String, value: any RuntimeValue) throws -> any RuntimeValue { - if self.variables.contains(where: { $0.0 == name }) { + private func declareVariable(name: String, value: any RuntimeValue) throws -> any RuntimeValue { + if self.variables.keys.contains(name) { throw JinjaError.syntax("Variable already declared: \(name)") } self.variables[name] = value - return value } @@ -191,13 +1655,13 @@ class Environment { return value } - func resolve(name: String) throws -> Self { - if self.variables.contains(where: { $0.0 == name }) { + private func resolve(name: String) throws -> Environment { + if self.variables.keys.contains(name) { return self } - if let parent { - return try parent.resolve(name: name) as! Self + if let parent = self.parent { + return try parent.resolve(name: name) } throw JinjaError.runtime("Unknown variable: \(name)") @@ -205,13 +1669,108 @@ class Environment { func lookupVariable(name: String) -> any RuntimeValue { do { - if let value = try self.resolve(name: name).variables[name] { - return value - } else { - return UndefinedValue() - } + return try self.resolve(name: name).variables[name] ?? UndefinedValue() } catch { return UndefinedValue() } } + + // Filters + + private func doDefault(_ args: [any RuntimeValue], _ env: Environment) throws -> any RuntimeValue { + let value = args[0] + let defaultValue = args.count > 1 ? args[1] : StringValue(value: "") + let boolean = args.count > 2 ? (args[2] as? BooleanValue)?.value ?? false : false + + if value is UndefinedValue { + return defaultValue + } + + if boolean { + if !value.bool() { + return defaultValue + } + // If it's a boolean value, return its string representation + if let boolValue = value as? BooleanValue { + return StringValue(value: String(boolValue.value)) + } + } + + return value + } + + private func doEscape(_ args: [any RuntimeValue], _ env: Environment) throws -> any RuntimeValue { + guard let stringValue = args[0] as? StringValue else { + throw JinjaError.runtime("escape filter requires a string") + } + return StringValue( + value: stringValue.value.replacingOccurrences(of: "&", with: "&") + .replacingOccurrences(of: "<", with: "<") + .replacingOccurrences(of: ">", with: ">") + .replacingOccurrences(of: "\"", with: """) + .replacingOccurrences(of: "'", with: "'") + ) + } + + private func doEqualTo(_ args: [any RuntimeValue]) -> Bool { + if args.count == 2 { + if let left = args[0] as? StringValue, let right = args[1] as? StringValue { + return left.value == right.value + } else if let left = args[0] as? NumericValue, let right = args[1] as? NumericValue, + let leftInt = left.value as? Int, let rightInt = right.value as? Int + { + return leftInt == rightInt + } else if let left = args[0] as? BooleanValue, let right = args[1] as? BooleanValue { + return left.value == right.value + } else { + return false + } + } else { + return false + } + } + + // Tests + + private func doGreaterThan(_ args: [any RuntimeValue]) throws -> Bool { + if let left = args[0] as? StringValue, let right = args[1] as? StringValue { + return left.value > right.value + } else if let left = args[0] as? NumericValue, let right = args[1] as? NumericValue { + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return leftInt > rightInt + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return leftDouble > rightDouble + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return Double(leftInt) > rightDouble + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return leftDouble > Double(rightInt) + } + } + throw JinjaError.runtime("Cannot compare values of different types") + } + + private func doGreaterThanOrEqual(_ args: [any RuntimeValue]) throws -> Bool { + return try doGreaterThan(args) || doEqualTo(args) + } + + private func doLessThan(_ args: [any RuntimeValue]) throws -> Bool { + if let left = args[0] as? StringValue, let right = args[1] as? StringValue { + return left.value < right.value + } else if let left = args[0] as? NumericValue, let right = args[1] as? NumericValue { + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return leftInt < rightInt + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return leftDouble < rightDouble + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return Double(leftInt) < rightDouble + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return leftDouble < Double(rightInt) + } + } + throw JinjaError.runtime("Cannot compare values of different types") + } + + private func doLessThanOrEqual(_ args: [any RuntimeValue]) throws -> Bool { + return try doLessThan(args) || doEqualTo(args) + } } diff --git a/Sources/Error.swift b/Sources/Error.swift index 7fe27b2..b0f6b6e 100644 --- a/Sources/Error.swift +++ b/Sources/Error.swift @@ -7,7 +7,7 @@ import Foundation -enum JinjaError: Error, LocalizedError { +enum JinjaError: Error, LocalizedError, Equatable { case syntax(String) case parser(String) case runtime(String) @@ -23,4 +23,8 @@ enum JinjaError: Error, LocalizedError { case .syntaxNotSupported(let string): return "Syntax not supported: \(string)" } } + + var id: String { + errorDescription ?? "" + } } diff --git a/Sources/Lexer.swift b/Sources/Lexer.swift index 1093960..3c9849d 100644 --- a/Sources/Lexer.swift +++ b/Sources/Lexer.swift @@ -50,6 +50,8 @@ enum TokenType: String { case and = "And" case or = "Or" case not = "Not" + case macro = "Macro" + case endMacro = "EndMacro" } struct Token: Equatable { @@ -70,6 +72,8 @@ let keywords: [String: TokenType] = [ "and": .and, "or": .or, "not": .not, + "macro": .macro, + "endmacro": .endMacro, // Literals "true": .booleanLiteral, "false": .booleanLiteral, @@ -81,7 +85,7 @@ func isWord(char: String) -> Bool { } func isInteger(char: String) -> Bool { - char.range(of: #"[0-9]"#, options: .regularExpression) != nil + char.range(of: #"^[0-9]+$"#, options: .regularExpression) != nil } func isWhile(char: String) -> Bool { @@ -136,21 +140,16 @@ struct PreprocessOptions { func preprocess(template: String, options: PreprocessOptions = PreprocessOptions()) -> String { var template = template - if template.hasSuffix("\n") { template.removeLast() } - template = template.replacing(#/{#.*?#}/#, with: "{##}") - if options.lstripBlocks == true { template = template.replacing(#/(?m)^[ \t]*({[#%])/#, with: { $0.output.1 }) } - if options.trimBlocks == true { template = template.replacing(#/([#%]})\n/#, with: { $0.output.1 }) } - return template .replacing(#/{##}/#, with: "") @@ -163,7 +162,6 @@ func preprocess(template: String, options: PreprocessOptions = PreprocessOptions func tokenize(_ source: String, options: PreprocessOptions = PreprocessOptions()) throws -> [Token] { var tokens: [Token] = [] let src = preprocess(template: source, options: options) - var cursorPosition = 0 @discardableResult @@ -175,17 +173,14 @@ func tokenize(_ source: String, options: PreprocessOptions = PreprocessOptions() if cursorPosition >= src.count { throw JinjaError.syntax("Unexpected end of input") } - let escaped = String(src[cursorPosition]) cursorPosition += 1 - guard let unescaped = escapeCharacters[escaped] else { throw JinjaError.syntax("Unexpected escaped character: \(escaped)") } str.append(unescaped) continue } - str.append(String(src[cursorPosition])) cursorPosition += 1 if cursorPosition >= src.count { @@ -197,7 +192,6 @@ func tokenize(_ source: String, options: PreprocessOptions = PreprocessOptions() main: while cursorPosition < src.count { let lastTokenType = tokens.last?.type - if lastTokenType == nil || lastTokenType == .closeStatement || lastTokenType == .closeExpression { var text = "" @@ -213,18 +207,13 @@ func tokenize(_ source: String, options: PreprocessOptions = PreprocessOptions() continue } } - try consumeWhile(predicate: isWhile) - let char = String(src[cursorPosition]) - if char == "-" || char == "+" { let lastTokenType = tokens.last?.type - if lastTokenType == .text || lastTokenType == nil { throw JinjaError.syntax("Unexpected character: \(char)") } - switch lastTokenType { case .identifier, .numericLiteral, @@ -234,18 +223,13 @@ func tokenize(_ source: String, options: PreprocessOptions = PreprocessOptions() .closeParen, .closeSquareBracket: break - default: cursorPosition += 1 - let num = try consumeWhile(predicate: isInteger) - tokens.append(Token(value: "\(char)\(num)", type: num.isEmpty ? .unaryOperator : .numericLiteral)) - continue } } - for (char, token) in orderedMappingTable { let slice = src.slice(start: cursorPosition, end: cursorPosition + char.count) if slice == char { @@ -254,7 +238,6 @@ func tokenize(_ source: String, options: PreprocessOptions = PreprocessOptions() continue main } } - if char == "'" || char == "\"" { cursorPosition += 1 let str = try consumeWhile { str in @@ -264,30 +247,23 @@ func tokenize(_ source: String, options: PreprocessOptions = PreprocessOptions() cursorPosition += 1 continue } - if isInteger(char: char) { let num = try consumeWhile(predicate: isInteger) tokens.append(Token(value: num, type: .numericLiteral)) continue } - if isWord(char: char) { let word = try consumeWhile(predicate: isWord) - let type: TokenType = keywords.contains(where: { $0.key == word }) ? keywords[word]! : .identifier - if type == .in, tokens.last?.type == .not { _ = tokens.popLast() tokens.append(Token(value: "not in", type: .notIn)) } else { tokens.append(Token(value: word, type: type)) } - continue } - throw JinjaError.syntax("Unexpected character: \(char)") } - return tokens } diff --git a/Sources/Parser.swift b/Sources/Parser.swift index 648a025..8514945 100644 --- a/Sources/Parser.swift +++ b/Sources/Parser.swift @@ -6,6 +6,7 @@ // import Foundation +import OrderedCollections func parse(tokens: [Token]) throws -> Program { var program = Program() @@ -22,40 +23,31 @@ func parse(tokens: [Token]) throws -> Program { return prev } - func parseArgumentsList() throws -> [Statement] { + func parseArgumentsList() throws -> [Expression] { var args: [Expression] = [] - while !typeof(.closeParen) { var argument = try parseExpression() - if typeof(.equals) { - current += 1 - + current += 1 // consume equals if let identifier = argument as? Identifier { let value = try parseExpression() - argument = KeywordArgumentExpression(key: identifier, value: value as! Expression) + argument = KeywordArgumentExpression(key: identifier, value: value) } else { throw JinjaError.syntax("Expected identifier for keyword argument") } } - - args.append(argument as! Expression) - + args.append(argument) if typeof(.comma) { - current += 1 + current += 1 // consume comma } } - return args } - func parseArgs() throws -> [Statement] { + func parseArgs() throws -> [Expression] { try expect(type: .openParen, error: "Expected opening parenthesis for arguments list") - let args = try parseArgumentsList() - try expect(type: .closeParen, error: "Expected closing parenthesis for arguments list") - return args } @@ -63,69 +55,75 @@ func parse(tokens: [Token]) throws -> Program { try StringLiteral(value: expect(type: .text, error: "Expected text token").value) } - func parseCallExpression(callee: Statement) throws -> CallExpression { - var args: [Expression] = [] - - for arg in try parseArgs() { - args.append(arg as! Expression) + func parseCallExpression(callee: Expression) throws -> Expression { + let args = try parseArgs() + var expression: Expression = CallExpression(callee: callee, args: args) + // Handle potential array indexing after method call + if typeof(.openSquareBracket) { + expression = MemberExpression( + object: expression, + property: try parseMemberExpressionArgumentsList(), + computed: true + ) } - - var callExpression = CallExpression(callee: callee as! Expression, args: args) - + // Handle potential chained method calls if typeof(.openParen) { - callExpression = try parseCallExpression(callee: callExpression) + expression = try parseCallExpression(callee: expression) } - - return callExpression + return expression } - func parseMemberExpressionArgumentsList() throws -> Statement { - var slices: [Statement?] = [] + func parseMemberExpressionArgumentsList() throws -> Expression { + var slices: [Expression?] = [] var isSlice = false - while !typeof(.closeSquareBracket) { if typeof(.colon) { slices.append(nil) - current += 1 + current += 1 // consume colon isSlice = true } else { - try slices.append(parseExpression()) + // Handle negative numbers as indices + if typeof(.additiveBinaryOperator) && tokens[current].value == "-" { + current += 1 // consume the minus sign + if typeof(.numericLiteral) { + let num = tokens[current].value + current += 1 + slices.append(NumericLiteral(value: -Int(num)!)) + } else { + throw JinjaError.syntax("Expected number after minus sign in array index") + } + } else { + slices.append(try parseExpression()) + } if typeof(.colon) { - current += 1 + current += 1 // consume colon isSlice = true } } } - if slices.isEmpty { throw JinjaError.syntax("Expected at least one argument for member/slice expression") } - if isSlice { if slices.count > 3 { throw JinjaError.syntax("Expected 0-3 arguments for slice expression") } - return SliceExpression( - start: slices[0] as? Expression, - stop: slices.count > 1 ? slices[1] as? Expression : nil, - step: slices.count > 2 ? slices[2] as? Expression : nil + start: slices[0], + stop: slices.count > 1 ? slices[1] : nil, + step: slices.count > 2 ? slices[2] : nil ) } - - return slices[0]! + return slices[0]! // normal member expression } - func parseMemberExpression() throws -> Statement { + func parseMemberExpression() throws -> Expression { var object = try parsePrimaryExpression() - while typeof(.dot) || typeof(.openSquareBracket) { let operation = tokens[current] current += 1 - var property: Statement - + var property: Expression let computed = operation.type != .dot - if computed { property = try parseMemberExpressionArgumentsList() try expect(type: .closeSquareBracket, error: "Expected closing square bracket") @@ -134,53 +132,127 @@ func parse(tokens: [Token]) throws -> Program { if !(property is Identifier) { throw JinjaError.syntax("Expected identifier following dot operator") } + // Handle method calls + if typeof(.openParen) { + let methodCall = CallExpression( + callee: MemberExpression(object: object, property: property, computed: false), + args: try parseArgs() + ) + // Handle array indexing after method call + if typeof(.openSquareBracket) { + current += 1 // consume [ + let index = try parseExpression() + try expect(type: .closeSquareBracket, error: "Expected closing square bracket") + object = MemberExpression(object: methodCall, property: index, computed: true) + continue + } + object = methodCall + continue + } } - object = MemberExpression( - object: object as! Expression, - property: property as! Expression, + object: object, + property: property, computed: computed ) } - return object } - func parseCallMemberExpression() throws -> Statement { + func parseCallMemberExpression() throws -> Expression { let member = try parseMemberExpression() - if typeof(.openParen) { return try parseCallExpression(callee: member) } - return member } - func parseFilterExpression() throws -> Statement { + func parseFilterExpression() throws -> Expression { var operand = try parseCallMemberExpression() - while typeof(.pipe) { - current += 1 - var filter = try parsePrimaryExpression() - if !(filter is Identifier) { - throw JinjaError.syntax("Expected identifier for the test") + current += 1 // consume pipe + guard let filterName = try parsePrimaryExpression() as? Identifier else { + throw JinjaError.syntax("Expected identifier for the filter") } - + var args: [Expression] = [] + var kwargs: [KeywordArgumentExpression] = [] + var dyn_args: Expression? + var dyn_kwargs: Expression? if typeof(.openParen) { - filter = try parseCallExpression(callee: filter) + // Handle filter with arguments + (args, kwargs, dyn_args, dyn_kwargs) = try parseCallArgs() } + operand = FilterExpression( + operand: operand, + filter: filterName, + args: args, + kwargs: kwargs, + dyn_args: dyn_args, + dyn_kwargs: dyn_kwargs + ) + } + return operand + } - if let filter = filter as? Filter { - operand = FilterExpression(operand: operand as! Expression, filter: filter) + func parseCallArgs() throws -> ( + [Expression], [KeywordArgumentExpression], Expression?, Expression? + ) { + try expect(type: .openParen, error: "Expected opening parenthesis for arguments list") + var args: [Expression] = [] + var kwargs: [KeywordArgumentExpression] = [] + var dynArgs: Expression? + var dynKwargs: Expression? + var requireComma = false + while !typeof(.closeParen) { + if requireComma { + try expect(type: .comma, error: "Expected comma between arguments") + if typeof(.closeParen) { + break + } } + if typeof(.multiplicativeBinaryOperator), tokens[current].value == "*" { + current += 1 // Consume * + if dynArgs != nil || dynKwargs != nil { + throw JinjaError.syntax("Multiple dynamic positional arguments are not allowed.") + } + dynArgs = try parseExpression() + } else if typeof(.multiplicativeBinaryOperator), tokens[current].value == "**" { + current += 1 // Consume ** + if dynKwargs != nil { + throw JinjaError.syntax("Multiple dynamic keyword arguments are not allowed.") + } + dynKwargs = try parseExpression() + } else { + if typeof(.identifier), tokens.count > current + 1, tokens[current + 1].type == .equals { + // Parse keyword argument + guard let key = try parsePrimaryExpression() as? Identifier else { + throw JinjaError.syntax("Expected identifier for keyword argument key") + } + try expect(type: .equals, error: "Expected '=' after keyword argument key") + let value = try parseExpression() + if dynKwargs != nil { + throw JinjaError.syntax("Keyword arguments must be after dynamic keyword arguments") + } + kwargs.append(KeywordArgumentExpression(key: key, value: value)) + } else { + // Parse positional argument + if !kwargs.isEmpty || dynKwargs != nil { + throw JinjaError.syntax("Positional argument after keyword argument") + } + if dynArgs != nil { + throw JinjaError.syntax("Positional arguments must be after dynamic positional arguments") + } + args.append(try parseExpression()) + } + } + requireComma = true } - - return operand + try expect(type: .closeParen, error: "Expected closing parenthesis for arguments list") + return (args, kwargs, dynArgs, dynKwargs) } - func parseTestExpression() throws -> Statement { + func parseTestExpression() throws -> Expression { var operand = try parseFilterExpression() - while typeof(.is) { current += 1 let negate = typeof(.not) @@ -194,7 +266,7 @@ func parse(tokens: [Token]) throws -> Program { filter = Identifier(value: "none") } if let test = filter as? Identifier { - operand = TestExpression(operand: operand as! Expression, negate: negate, test: test) + operand = TestExpression(operand: operand, negate: negate, test: test) } else { throw JinjaError.syntax("Expected identifier for the test") } @@ -202,96 +274,116 @@ func parse(tokens: [Token]) throws -> Program { return operand } - func parseMultiplicativeExpression() throws -> Statement { + func parseMultiplicativeExpression() throws -> Expression { var left = try parseTestExpression() - while typeof(.multiplicativeBinaryOperator) { let operation = tokens[current] current += 1 let right = try parseTestExpression() - left = BinaryExpression(operation: operation, left: left as! Expression, right: right as! Expression) + left = BinaryExpression(operation: operation, left: left, right: right) } return left } - func parseAdditiveExpression() throws -> Statement { + func parseAdditiveExpression() throws -> Expression { var left = try parseMultiplicativeExpression() while typeof(.additiveBinaryOperator) { let operation = tokens[current] current += 1 let right = try parseMultiplicativeExpression() - left = BinaryExpression(operation: operation, left: left as! Expression, right: right as! Expression) + left = BinaryExpression(operation: operation, left: left, right: right) } return left } - func parseComparisonExpression() throws -> Statement { + func parseComparisonExpression() throws -> Expression { var left = try parseAdditiveExpression() - while typeof(.comparisonBinaryOperator) || typeof(.in) || typeof(.notIn) { + while typeof(.comparisonBinaryOperator) || typeof(.in) || typeof(.notIn) + || (typeof(.is) + && (tokens.count > current + 1 + && (tokens[current + 1].type == .identifier || tokens[current + 1].type == .not))) + { let operation = tokens[current] current += 1 - let right = try parseAdditiveExpression() - left = BinaryExpression(operation: operation, left: left as! Expression, right: right as! Expression) + if operation.type == .is { + if typeof(.not) { + current += 1 + if typeof(.identifier), tokens[current].value == "none" { + current += 1 + left = TestExpression(operand: left, negate: true, test: Identifier(value: "none")) + continue + } else { + throw JinjaError.syntax("Expected 'none' after 'is not'") + } + } else if typeof(.identifier), tokens[current].value == "defined" { + current += 1 + left = TestExpression(operand: left, negate: false, test: Identifier(value: "defined")) + continue + } else { + throw JinjaError.syntax("Expected 'defined' or 'not' after 'is'") + } + } else if operation.type == .notIn { + let right = try parseAdditiveExpression() + left = BinaryExpression(operation: operation, left: left, right: right) + } else { + let right = try parseAdditiveExpression() + left = BinaryExpression(operation: operation, left: left, right: right) + } } - return left } - func parseLogicalNegationExpression() throws -> Statement { - var right: UnaryExpression? - - while typeof(.not) { + func parseLogicalNegationExpression() throws -> Expression { + if typeof(.not) { let operation = tokens[current] current += 1 let argument = try parseLogicalNegationExpression() - right = UnaryExpression(operation: operation, argument: argument as! Expression) - } - - if let right { - return right + return UnaryExpression(operation: operation, argument: argument) } else { return try parseComparisonExpression() } } - func parseLogicalAndExpression() throws -> Statement { + func parseLogicalAndExpression() throws -> Expression { var left = try parseLogicalNegationExpression() while typeof(.and) { let operation = tokens[current] current += 1 let right = try parseLogicalNegationExpression() - left = BinaryExpression(operation: operation, left: left as! Expression, right: right as! Expression) + left = BinaryExpression(operation: operation, left: left, right: right) } - return left } - func parseLogicalOrExpression() throws -> Statement { + func parseLogicalOrExpression() throws -> Expression { var left = try parseLogicalAndExpression() - while typeof(.or) { - let operation = tokens[current] - current += 1 + current += 1 // Consume 'or' let right = try parseLogicalAndExpression() - left = BinaryExpression(operation: operation, left: left as! Expression, right: right as! Expression) + left = BinaryExpression(operation: Token(value: "or", type: .or), left: left, right: right) } return left } - func parseTernaryExpression() throws -> Statement { + func parseTernaryExpression() throws -> Expression { let a = try parseLogicalOrExpression() if typeof(.if) { - current += 1 - let test = try parseLogicalOrExpression() - try expect(type: .else, error: "Expected else token") - let b = try parseLogicalOrExpression() - return If(test: test as! Expression, body: [a], alternate: [b]) + current += 1 // consume if token + let predicate = try parseLogicalOrExpression() + if typeof(.else) { + // Ternary expression with else + current += 1 // consume else token + let b = try parseLogicalOrExpression() + return If(test: predicate, body: [a], alternate: [b]) + } else { + // Select expression on iterable + return SelectExpression(iterable: a, test: predicate) + } } - return a } - func parseExpression() throws -> Statement { + func parseExpression() throws -> Expression { try parseTernaryExpression() } @@ -299,66 +391,68 @@ func parse(tokens: [Token]) throws -> Program { guard current + types.count <= tokens.count else { return false } - for (index, type) in types.enumerated() { if type != tokens[current + index].type { return false } } - return true } func parseSetStatement() throws -> Statement { let left = try parseExpression() - if typeof(.equals) { - current += 1 - let value = try parseSetStatement() - - return Set(assignee: left as! Expression, value: value as! Expression) + current += 1 // consume equals + // Parse the right-hand side as an expression + let value = try parseExpression() + try expect(type: .closeStatement, error: "Expected closing statement token") + return Set(assignee: left, value: value) } - + // If there's no equals sign, treat it as an expression statement + try expect(type: .closeStatement, error: "Expected closing statement token") return left } func parseIfStatement() throws -> Statement { let test = try parseExpression() - try expect(type: .closeStatement, error: "Expected closing statement token") - var body: [Statement] = [] var alternate: [Statement] = [] - while !(tokens[current].type == .openStatement && (tokens[current + 1].type == .elseIf || tokens[current + 1].type == .else || tokens[current + 1].type == .endIf)) { - try body.append(parseAny()) + body.append(try parseAny()) } if tokens[current].type == .openStatement, tokens[current + 1].type != .endIf { current += 1 if typeof(.elseIf) { try expect(type: .elseIf, error: "Expected elseif token") - try alternate.append(parseIfStatement()) + alternate.append(try parseIfStatement()) } else { try expect(type: .else, error: "Expected else token") try expect(type: .closeStatement, error: "Expected closing statement token") while !(tokens[current].type == .openStatement && tokens[current + 1].type == .endIf) { - try alternate.append(parseAny()) + alternate.append(try parseAny()) } } } - return If(test: test as! Expression, body: body, alternate: alternate) + return If(test: test, body: body, alternate: alternate) } - func parsePrimaryExpression() throws -> Statement { + func parsePrimaryExpression() throws -> Expression { let token = tokens[current] switch token.type { case .numericLiteral: current += 1 - return NumericLiteral(value: Int(token.value) ?? 0) + if let intValue = Int(token.value) { + return NumericLiteral(value: intValue) + } else if let doubleValue = Double(token.value) { + return NumericLiteral(value: doubleValue) + } else { + throw JinjaError.parser("Invalid numeric literal: \(token.value)") + } case .stringLiteral: current += 1 return StringLiteral(value: token.value) @@ -383,7 +477,7 @@ func parse(tokens: [Token]) throws -> Program { current += 1 var values: [Expression] = [] while !typeof(.closeSquareBracket) { - try values.append(parseExpression() as! Expression) + try values.append(parseExpression()) if typeof(.comma) { current += 1 } @@ -392,12 +486,20 @@ func parse(tokens: [Token]) throws -> Program { return ArrayLiteral(value: values) case .openCurlyBracket: current += 1 - var values: [(Expression, Expression)] = [] + var values = OrderedDictionary() while !typeof(.closeCurlyBracket) { let key = try parseExpression() try expect(type: .colon, error: "Expected colon between key and value in object literal") let value = try parseExpression() - values.append((key as! Expression, value as! Expression)) + + if let key = key as? StringLiteral { + values[key.value] = value + } else if let key = key as? Identifier { + values[key.value] = value + } else { + throw JinjaError.syntax("Expected string literal or identifier as key in object literal") + } + if typeof(.comma) { current += 1 } @@ -409,18 +511,18 @@ func parse(tokens: [Token]) throws -> Program { } } - func parseExpressionSequence(primary: Bool = false) throws -> Statement { + func parseExpressionSequence(primary: Bool = false) throws -> Expression { let fn = primary ? parsePrimaryExpression : parseExpression - var expressions: [Expression] = try [fn() as! Expression] + var expressions: [Expression] = try [fn()] let isTuple = typeof(.comma) while isTuple { - current += 1 - try expressions.append(fn() as! Expression) + current += 1 // consume comma + try expressions.append(fn()) if !typeof(.comma) { break } } - + // Return either a tuple or single expression return isTuple ? TupleLiteral(value: expressions) : expressions[0] } @@ -428,7 +530,6 @@ func parse(tokens: [Token]) throws -> Program { guard current + types.count <= tokens.count else { return false } - return types.enumerated().contains { i, type -> Bool in type != tokens[current + i].type } @@ -436,56 +537,88 @@ func parse(tokens: [Token]) throws -> Program { func parseForStatement() throws -> Statement { let loopVariable = try parseExpressionSequence(primary: true) - if !(loopVariable is Identifier || loopVariable is TupleLiteral) { throw JinjaError.syntax( - "Expected identifier/tuple for the loop variable, got \(type(of:loopVariable)) instead" + "Expected identifier/tuple for the loop variable, got \(type(of: loopVariable)) instead" ) } - try expect(type: .in, error: "Expected `in` keyword following loop variable") - let iterable = try parseExpression() - + // Handle optional if condition for filtering + var test: Expression? = nil + if typeof(.if) { + current += 1 // consume if token + test = try parseExpression() + } try expect(type: .closeStatement, error: "Expected closing statement token") - var body: [Statement] = [] - while not(.openStatement, .endFor) { - try body.append(parseAny()) + var defaultBlock: [Statement] = [] + while not(.openStatement, .endFor) && not(.openStatement, .else) { + body.append(try parseAny()) } + if typeof(.openStatement, .else) { + current += 1 // consume {% + try expect(type: .else, error: "Expected else token") + try expect(type: .closeStatement, error: "Expected closing statement token") - if let loopVariable = loopVariable as? Loopvar { - return For(loopvar: loopVariable, iterable: iterable as! Expression, body: body) + while not(.openStatement, .endFor) { + defaultBlock.append(try parseAny()) + } } - - throw JinjaError.syntax( - "Expected identifier/tuple for the loop variable, got \(type(of:loopVariable)) instead" + return For( + loopvar: loopVariable, + iterable: iterable, + body: body, + defaultBlock: defaultBlock, + test: test ) } + func parseMacroStatement() throws -> Macro { + let name = try parsePrimaryExpression() + if !(name is Identifier) { + throw JinjaError.syntax("Expected identifier following macro statement") + } + let args = try parseArgs() + try expect(type: .closeStatement, error: "Expected closing statement token") + var body: [Statement] = [] + while not(.openStatement, .endMacro) { + body.append(try parseAny()) + } + return Macro(name: name as! Identifier, args: args, body: body) + } + func parseJinjaStatement() throws -> Statement { + // Consume {% %} tokens try expect(type: .openStatement, error: "Expected opening statement token") var result: Statement switch tokens[current].type { case .set: - current += 1 + current += 1 // consume 'set' token result = try parseSetStatement() - try expect(type: .closeStatement, error: "Expected closing statement token") case .if: - current += 1 + current += 1 // consume 'if' token result = try parseIfStatement() try expect(type: .openStatement, error: "Expected {% token") try expect(type: .endIf, error: "Expected endif token") try expect(type: .closeStatement, error: "Expected %} token") + case .macro: + current += 1 // consume 'macro' token + result = try parseMacroStatement() + try expect(type: .openStatement, error: "Expected {% token") + try expect(type: .endMacro, error: "Expected endmacro token") + try expect(type: .closeStatement, error: "Expected %} token") case .for: - current += 1 + current += 1 // consume 'for' token result = try parseForStatement() try expect(type: .openStatement, error: "Expected {% token") try expect(type: .endFor, error: "Expected endfor token") try expect(type: .closeStatement, error: "Expected %} token") default: - throw JinjaError.syntax("Unknown statement type: \(tokens[current].type)") + // Handle expressions within statements + result = try parseExpression() + try expect(type: .closeStatement, error: "Expected closing statement token") } return result @@ -493,11 +626,8 @@ func parse(tokens: [Token]) throws -> Program { func parseJinjaExpression() throws -> Statement { try expect(type: .openExpression, error: "Expected opening expression token") - let result = try parseExpression() - try expect(type: .closeExpression, error: "Expected closing expression token") - return result } diff --git a/Sources/Runtime.swift b/Sources/Runtime.swift index 73a0d48..c0f6fb4 100644 --- a/Sources/Runtime.swift +++ b/Sources/Runtime.swift @@ -6,11 +6,12 @@ // import Foundation +import OrderedCollections protocol RuntimeValue { - associatedtype T - var value: T { get set } + associatedtype ValueType + var value: ValueType { get } var builtins: [String: any RuntimeValue] { get set } func bool() -> Bool @@ -21,7 +22,12 @@ struct NumericValue: RuntimeValue { var builtins: [String: any RuntimeValue] = [:] func bool() -> Bool { - self.value as? Int != 0 + if let intValue = self.value as? Int { + return intValue != 0 + } else if let doubleValue = self.value as? Double { + return doubleValue != 0.0 + } + return false } } @@ -35,7 +41,7 @@ struct BooleanValue: RuntimeValue { } struct NullValue: RuntimeValue { - var value: (any RuntimeValue)? + let value: Any? = nil var builtins: [String: any RuntimeValue] = [:] func bool() -> Bool { @@ -44,7 +50,7 @@ struct NullValue: RuntimeValue { } struct UndefinedValue: RuntimeValue { - var value: (any RuntimeValue)? + let value: Any? = nil var builtins: [String: any RuntimeValue] = [:] func bool() -> Bool { @@ -64,56 +70,85 @@ struct ArrayValue: RuntimeValue { } func bool() -> Bool { - !self.value.isEmpty + return !self.value.isEmpty } } struct TupleValue: RuntimeValue { - var value: ArrayValue + var value: [any RuntimeValue] var builtins: [String: any RuntimeValue] = [:] + init(value: [any RuntimeValue]) { + self.value = value + self.builtins["length"] = FunctionValue(value: { _, _ in + NumericValue(value: value.count) + }) + } + func bool() -> Bool { - self.value.bool() + !self.value.isEmpty } } -struct ObjectValue: RuntimeValue { - var value: [String: any RuntimeValue] - var builtins: [String: any RuntimeValue] = [:] +struct ObjectValue: RuntimeValue, Sequence { + var storage: OrderedDictionary + var builtins: [String: any RuntimeValue] - init(value: [String: any RuntimeValue]) { - self.value = value + var value: [String: any RuntimeValue] { Dictionary(uniqueKeysWithValues: storage.map { ($0, $1) }) } + var orderedKeys: [String] { Array(storage.keys) } + + init(value: [String: any RuntimeValue], keyOrder: [String]? = nil) { + // If keyOrder is provided, use it; otherwise, maintain the original order from the dictionary + let orderedKeys = keyOrder ?? Array(value.keys) + let orderedPairs = orderedKeys.compactMap { key in + value[key].map { (key, $0) } + } + + // Recursively create OrderedDictionary for nested objects + let processedPairs = orderedPairs.map { key, value -> (String, any RuntimeValue) in + if let objectValue = value as? ObjectValue { + // Already an ObjectValue, use it directly + return (key, objectValue) + } else if let dictValue = value.value as? [String: any RuntimeValue] { + // If the value contains a dictionary, convert it to ObjectValue + return (key, ObjectValue(value: dictValue)) + } + return (key, value) + } + + self.storage = OrderedDictionary(uniqueKeysWithValues: processedPairs) self.builtins = [ "get": FunctionValue(value: { args, _ in - if let key = args[0] as? StringValue { - if let value = value.first(where: { $0.0 == key.value }) { - return value as! (any RuntimeValue) - } else if args.count > 1 { - return args[1] - } else { - return NullValue() - } - } else { - throw JinjaError.runtime("Object key must be a string: got \(type(of:args[0]))") + guard let key = args[0] as? StringValue else { + throw JinjaError.runtime("Object key must be a string: got \(type(of: args[0]))") } + if let value = value[key.value] { + return value + } else if args.count > 1 { + return args[1] + } + return NullValue() }), "items": FunctionValue(value: { _, _ in - var items: [ArrayValue] = [] - for (k, v) in value { - items.append( - ArrayValue(value: [ - StringValue(value: k), - v, - ]) - ) - } - return items as! (any RuntimeValue) + ArrayValue( + value: orderedPairs.map { key, value in + ArrayValue(value: [StringValue(value: key), value]) + } + ) }), ] } + mutating func setValue(key: String, value: any RuntimeValue) { + storage[key] = value + } + func bool() -> Bool { - !self.value.isEmpty + !storage.isEmpty + } + + func makeIterator() -> OrderedDictionary.Iterator { + return storage.makeIterator() } } @@ -136,22 +171,24 @@ struct StringValue: RuntimeValue { "upper": FunctionValue(value: { _, _ in StringValue(value: value.uppercased()) }), - "lower": FunctionValue(value: { _, _ in StringValue(value: value.lowercased()) }), - "strip": FunctionValue(value: { _, _ in StringValue(value: value.trimmingCharacters(in: .whitespacesAndNewlines)) }), - "title": FunctionValue(value: { _, _ in - StringValue(value: value.capitalized) + StringValue(value: value.titleCase()) }), - "length": FunctionValue(value: { _, _ in NumericValue(value: value.count) }), + "rstrip": FunctionValue(value: { _, _ in + StringValue(value: value.replacingOccurrences(of: "\\s+$", with: "", options: .regularExpression)) + }), + "lstrip": FunctionValue(value: { _, _ in + StringValue(value: value.replacingOccurrences(of: "^\\s+", with: "", options: .regularExpression)) + }), ] } @@ -175,23 +212,24 @@ struct Interpreter { var result = "" for statement in statements { let lastEvaluated = try self.evaluate(statement: statement, environment: environment) - if !(lastEvaluated is NullValue), !(lastEvaluated is UndefinedValue) { - if let value = lastEvaluated.value as? String { - result += value + if let stringValue = lastEvaluated as? StringValue { + result += stringValue.value + } else if let numericValue = lastEvaluated as? NumericValue { + result += String(describing: numericValue.value) + } else if let booleanValue = lastEvaluated as? BooleanValue { + result += String(booleanValue.value) + } else if let arrayValue = lastEvaluated as? ArrayValue { + // Convert array to JSON string + result += try toJSON(arrayValue) + } else if let objectValue = lastEvaluated as? ObjectValue { + // Convert object to JSON string + result += try toJSON(objectValue) } else { - switch lastEvaluated.value { - case let value as Int: - result += String(value) - case let value as String: - result += value - default: - throw JinjaError.runtime("Unknown value type:\(type(of: lastEvaluated.value))") - } + throw JinjaError.runtime("Cannot convert to string: \(type(of: lastEvaluated))") } } } - return StringValue(value: result) } @@ -206,229 +244,543 @@ struct Interpreter { try environment.setVariable(name: variableName, value: rhs) } else if let member = node.assignee as? MemberExpression { let object = try self.evaluate(statement: member.object, environment: environment) - - if var object = object as? ObjectValue { - if let property = member.property as? Identifier { - object.value[property.value] = rhs - } else { - throw JinjaError.runtime("Cannot assign to member with non-identifier property") - } - } else { + guard var objectValue = object as? ObjectValue else { throw JinjaError.runtime("Cannot assign to member of non-object") } + guard let property = member.property as? Identifier else { + throw JinjaError.runtime("Cannot assign to member with non-identifier property") + } + // Modify the copy + objectValue.setValue(key: property.value, value: rhs) + // Update the environment with the modified copy + if let parentIdentifier = member.object as? Identifier { + try environment.setVariable(name: parentIdentifier.value, value: objectValue) + } else { + throw JinjaError.runtime("Cannot assign to computed member expression") + } } else { - throw JinjaError.runtime("Invalid assignee type: \(type(of: node.assignee))") + throw JinjaError.runtime("Invalid LHS inside assignment expression: \(node.assignee)") } - return NullValue() } func evaluateIf(node: If, environment: Environment) throws -> StringValue { let test = try self.evaluate(statement: node.test, environment: environment) - return try self.evaluateBlock(statements: test.bool() ? node.body : node.alternate, environment: environment) } func evaluateIdentifier(node: Identifier, environment: Environment) throws -> any RuntimeValue { - environment.lookupVariable(name: node.value) + let value = environment.lookupVariable(name: node.value) + return value } - func evaluateFor(node: For, environment: Environment) throws -> any RuntimeValue { + func evaluateFor(node: For, environment: Environment) throws -> StringValue { + // Scope for the for loop let scope = Environment(parent: environment) - - let iterable = try self.evaluate(statement: node.iterable, environment: scope) - var result = "" - if let iterable = iterable as? ArrayValue { - for i in 0 ..< iterable.value.count { - let loop: [String: any RuntimeValue] = [ - "index": NumericValue(value: i + 1), - "index0": NumericValue(value: i), - "revindex": NumericValue(value: iterable.value.count - i), - "revindex0": NumericValue(value: iterable.value.count - i - 1), - "first": BooleanValue(value: i == 0), - "last": BooleanValue(value: i == iterable.value.count - 1), - "length": NumericValue(value: iterable.value.count), - "previtem": i > 0 ? iterable.value[i - 1] : UndefinedValue(), - "nextitem": i < iterable.value.count - 1 ? iterable.value[i + 1] : UndefinedValue(), - ] - - try scope.setVariable(name: "loop", value: ObjectValue(value: loop)) - - let current = iterable.value[i] - + let test: Expression? + let iterable: any RuntimeValue + if let selectExpression = node.iterable as? SelectExpression { + iterable = try self.evaluate(statement: selectExpression.iterable, environment: scope) + test = selectExpression.test + } else { + iterable = try self.evaluate(statement: node.iterable, environment: scope) + test = nil + } + var items: [any RuntimeValue] = [] + var scopeUpdateFunctions: [(Environment) throws -> Void] = [] + // Keep track of the indices of the original iterable that passed the test + var filteredIndices: [Int] = [] + var originalIndex = 0 + // Handle ArrayValue + if let arrayIterable = iterable as? ArrayValue { + for current in arrayIterable.value { + let loopScope = Environment(parent: scope) + var scopeUpdateFunction: (Environment) throws -> Void if let identifier = node.loopvar as? Identifier { - try scope.setVariable(name: identifier.value, value: current) - } else { - } - - switch node.loopvar { - case let identifier as Identifier: - try scope.setVariable(name: identifier.value, value: current) - case let tupleLiteral as TupleLiteral: - if let current = current as? ArrayValue { - if tupleLiteral.value.count != current.value.count { - throw JinjaError.runtime( - "Too \(tupleLiteral.value.count > current.value.count ? "few" : "many") items to unpack" - ) - } - - for j in 0 ..< tupleLiteral.value.count { - if let identifier = tupleLiteral.value[j] as? Identifier { - try scope.setVariable(name: identifier.value, value: current.value[j]) - } else { - throw JinjaError.runtime( - "Cannot unpack non-identifier type: \(type(of:tupleLiteral.value[j]))" - ) + scopeUpdateFunction = { scope in + try scope.setVariable(name: identifier.value, value: current) + } + } else if let tupleLiteral = node.loopvar as? TupleLiteral { + guard let currentArray = current as? ArrayValue else { + throw JinjaError.runtime("Cannot unpack non-iterable type: \(type(of: current))") + } + if tupleLiteral.value.count != currentArray.value.count { + throw JinjaError.runtime( + "Too \(tupleLiteral.value.count > currentArray.value.count ? "few" : "many") items to unpack" + ) + } + scopeUpdateFunction = { scope in + for (i, value) in tupleLiteral.value.enumerated() { + guard let identifier = value as? Identifier else { + throw JinjaError.runtime("Cannot unpack non-identifier type: \(type(of: value))") } + try scope.setVariable(name: identifier.value, value: currentArray.value[i]) } - } else { - throw JinjaError.runtime("Cannot unpack non-iterable type: \(type(of:current))") } - default: - throw JinjaError.syntaxNotSupported(String(describing: node.loopvar)) + } else { + throw JinjaError.runtime("Invalid loop variable(s): \(type(of: node.loopvar))") } - - let evaluated = try self.evaluateBlock(statements: node.body, environment: scope) - result += evaluated.value + // Evaluate the test before adding the item + if let test { + try scopeUpdateFunction(loopScope) + let testValue = try self.evaluate(statement: test, environment: loopScope) + if !testValue.bool() { + originalIndex += 1 + continue + } + } + items.append(current) + scopeUpdateFunctions.append(scopeUpdateFunction) + filteredIndices.append(originalIndex) + originalIndex += 1 + } + // Handle StringValue as a special case + } else if let stringIterable = iterable as? StringValue { + // Treat the string as an iterable of characters + for char in stringIterable.value { + let current = StringValue(value: String(char)) + let loopScope = Environment(parent: scope) + var scopeUpdateFunction: (Environment) throws -> Void + if let identifier = node.loopvar as? Identifier { + scopeUpdateFunction = { scope in + try scope.setVariable(name: identifier.value, value: current) + } + } else { + throw JinjaError.runtime("Invalid loop variable(s): \(type(of: node.loopvar))") + } + // Evaluate the test before adding the item + if let test = test { + try scopeUpdateFunction(loopScope) + let testValue = try self.evaluate(statement: test, environment: loopScope) + if !testValue.bool() { + originalIndex += 1 + continue + } + } + items.append(current) + scopeUpdateFunctions.append(scopeUpdateFunction) + filteredIndices.append(originalIndex) + originalIndex += 1 + } + // Handle ObjectValue (dictionary) + } else if let objectIterable = iterable as? ObjectValue { + // Treat the dictionary as an iterable of key-value pairs + for (key, value) in objectIterable { + let current = ArrayValue(value: [StringValue(value: key), value]) + let loopScope = Environment(parent: scope) + var scopeUpdateFunction: (Environment) throws -> Void + if let identifier = node.loopvar as? Identifier { + scopeUpdateFunction = { scope in + try scope.setVariable(name: identifier.value, value: current) + } + } else if let tupleLiteral = node.loopvar as? TupleLiteral { + // Support unpacking of key-value pairs into two variables + if tupleLiteral.value.count != 2 { + throw JinjaError.runtime( + "Cannot unpack dictionary entry: expected 2 variables, got \(tupleLiteral.value.count)" + ) + } + guard let keyIdentifier = tupleLiteral.value[0] as? Identifier else { + throw JinjaError.runtime( + "Cannot unpack dictionary entry into non-identifier: \(type(of: tupleLiteral.value[0]))" + ) + } + guard let valueIdentifier = tupleLiteral.value[1] as? Identifier else { + throw JinjaError.runtime( + "Cannot unpack dictionary entry into non-identifier: \(type(of: tupleLiteral.value[1]))" + ) + } + scopeUpdateFunction = { scope in + try scope.setVariable(name: keyIdentifier.value, value: StringValue(value: key)) + try scope.setVariable(name: valueIdentifier.value, value: value) + } + } else { + throw JinjaError.runtime("Invalid loop variable(s): \(type(of: node.loopvar))") + } + // Evaluate the test before adding the item + if let test = test { + try scopeUpdateFunction(loopScope) + let testValue = try self.evaluate(statement: test, environment: loopScope) + if !testValue.bool() { + originalIndex += 1 + continue + } + } + items.append(current) + scopeUpdateFunctions.append(scopeUpdateFunction) + filteredIndices.append(originalIndex) + originalIndex += 1 } } else { - throw JinjaError.runtime("Expected iterable type in for loop: got \(type(of:iterable))") + throw JinjaError.runtime("Expected iterable type in for loop: got \(type(of: iterable))") + } + var result = "" + var noIteration = true + for i in 0 ..< items.count { + // Get the previous and next items that passed the filter + let previousIndex = filteredIndices.firstIndex(of: filteredIndices[i])! - 1 + let nextIndex = filteredIndices.firstIndex(of: filteredIndices[i])! + 1 + let previtem: any RuntimeValue + if previousIndex >= 0 { + let previousFilteredIndex = filteredIndices[previousIndex] + if let arrayIterable = iterable as? ArrayValue { + previtem = arrayIterable.value[previousFilteredIndex] + } else if let stringIterable = iterable as? StringValue { + let index = stringIterable.value.index( + stringIterable.value.startIndex, + offsetBy: previousFilteredIndex + ) + previtem = StringValue(value: String(stringIterable.value[index])) + } else if let objectIterable = iterable as? ObjectValue { + let (key, value) = objectIterable.storage.elements[previousFilteredIndex] + previtem = ArrayValue(value: [StringValue(value: key), value]) + } else { + previtem = UndefinedValue() + } + } else { + previtem = UndefinedValue() + } + let nextitem: any RuntimeValue + if nextIndex < filteredIndices.count { + let nextFilteredIndex = filteredIndices[nextIndex] + if let arrayIterable = iterable as? ArrayValue { + nextitem = arrayIterable.value[nextFilteredIndex] + } else if let stringIterable = iterable as? StringValue { + let index = stringIterable.value.index(stringIterable.value.startIndex, offsetBy: nextFilteredIndex) + nextitem = StringValue(value: String(stringIterable.value[index])) + } else if let objectIterable = iterable as? ObjectValue { + let (key, value) = objectIterable.storage.elements[nextFilteredIndex] + nextitem = ArrayValue(value: [StringValue(value: key), value]) + } else { + nextitem = UndefinedValue() + } + } else { + nextitem = UndefinedValue() + } + let loop: [String: any RuntimeValue] = [ + "index": NumericValue(value: i + 1), + "index0": NumericValue(value: i), + "revindex": NumericValue(value: items.count - i), + "revindex0": NumericValue(value: items.count - i - 1), + "first": BooleanValue(value: i == 0), + "last": BooleanValue(value: i == items.count - 1), + "length": NumericValue(value: items.count), + "previtem": previtem, + "nextitem": nextitem, + ] + try scope.setVariable(name: "loop", value: ObjectValue(value: loop)) + try scopeUpdateFunctions[i](scope) + let evaluated = try self.evaluateBlock(statements: node.body, environment: scope) + result += evaluated.value + noIteration = false + } + if noIteration { + let defaultEvaluated = try self.evaluateBlock(statements: node.defaultBlock, environment: scope) + result += defaultEvaluated.value } - return StringValue(value: result) } func evaluateBinaryExpression(node: BinaryExpression, environment: Environment) throws -> any RuntimeValue { let left = try self.evaluate(statement: node.left, environment: environment) - + let right = try self.evaluate(statement: node.right, environment: environment) + // Handle 'or' + if node.operation.value == "or" { + if left.bool() { + return left + } else { + return right + } + } + // Handle 'and' if node.operation.value == "and" { - return left.bool() ? try self.evaluate(statement: node.right, environment: environment) : left - } else if node.operation.value == "or" { - return left.bool() ? left : try self.evaluate(statement: node.right, environment: environment) + if !left.bool() { + return left + } else { + return right + } } - - let right = try self.evaluate(statement: node.right, environment: environment) - + // == if node.operation.value == "==" { - switch left.value { - case let value as String: - return BooleanValue(value: value == right.value as! String) - case let value as Int: - return BooleanValue(value: value == right.value as! Int) - case let value as Bool: - return BooleanValue(value: value == right.value as! Bool) - default: - throw JinjaError.runtime( - "Unknown left value type:\(type(of: left.value)), right value type:\(type(of: right.value))" - ) + // Handle array indexing for right operand + if let memberExpr = node.right as? MemberExpression, + let arrayValue = try self.evaluate(statement: memberExpr.object, environment: environment) + as? ArrayValue, + let indexExpr = memberExpr.property as? NumericLiteral, + let index = indexExpr.value as? Int + { + + // Handle negative indices + let actualIndex = index < 0 ? arrayValue.value.count + index : index + if actualIndex >= 0 && actualIndex < arrayValue.value.count { + let rightValue = arrayValue.value[actualIndex] + return BooleanValue(value: try areEqual(left, rightValue)) + } } - } else if node.operation.value == "!=" { - if type(of: left) != type(of: right) { + + return BooleanValue(value: try areEqual(left, right)) + } + // != + if node.operation.value == "!=" { + if let left = left as? StringValue, let right = right as? StringValue { + return BooleanValue(value: left.value != right.value) + } else if let left = left as? NumericValue, let right = right as? NumericValue { + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return BooleanValue(value: leftInt != rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return BooleanValue(value: leftDouble != rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return BooleanValue(value: Double(leftInt) != rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return BooleanValue(value: leftDouble != Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for inequality comparison") + } + } else if let left = left as? BooleanValue, let right = right as? BooleanValue { + return BooleanValue(value: left.value != right.value) + } else if left is NullValue, right is NullValue { + return BooleanValue(value: false) + } else if left is UndefinedValue, right is UndefinedValue { + return BooleanValue(value: false) + } else if type(of: left) == type(of: right) { return BooleanValue(value: true) } else { - return BooleanValue(value: left.value as! AnyHashable != right.value as! AnyHashable) + return BooleanValue(value: true) } } - if left is UndefinedValue || right is UndefinedValue { throw JinjaError.runtime("Cannot perform operation on undefined values") } else if left is NullValue || right is NullValue { throw JinjaError.runtime("Cannot perform operation on null values") } else if let left = left as? NumericValue, let right = right as? NumericValue { switch node.operation.value { - case "+": throw JinjaError.syntaxNotSupported("+") - case "-": throw JinjaError.syntaxNotSupported("-") - case "*": throw JinjaError.syntaxNotSupported("*") - case "/": throw JinjaError.syntaxNotSupported("/") + case "+": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return NumericValue(value: leftInt + rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return NumericValue(value: leftDouble + rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return NumericValue(value: Double(leftInt) + rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return NumericValue(value: leftDouble + Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for addition") + } + case "-": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return NumericValue(value: leftInt - rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return NumericValue(value: leftDouble - rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return NumericValue(value: Double(leftInt) - rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return NumericValue(value: leftDouble - Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for subtraction") + } + case "*": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return NumericValue(value: leftInt * rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return NumericValue(value: leftDouble * rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return NumericValue(value: Double(leftInt) * rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return NumericValue(value: leftDouble * Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for multiplication") + } + case "/": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return NumericValue(value: leftInt / rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return NumericValue(value: leftDouble / rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return NumericValue(value: Double(leftInt) / rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return NumericValue(value: leftDouble / Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for division") + } case "%": - switch left.value { - case is Int: - return NumericValue(value: left.value as! Int % (right.value as! Int)) - default: - throw JinjaError.runtime("Unknown value type:\(type(of: left.value))") - } - case "<": throw JinjaError.syntaxNotSupported("<") - case ">": throw JinjaError.syntaxNotSupported(">") - case ">=": throw JinjaError.syntaxNotSupported(">=") - case "<=": throw JinjaError.syntaxNotSupported("<=") + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return NumericValue(value: leftInt % rightInt) + } else { + throw JinjaError.runtime("Unsupported numeric types for modulus") + } + case "<": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return BooleanValue(value: leftInt < rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return BooleanValue(value: leftDouble < rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return BooleanValue(value: Double(leftInt) < rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return BooleanValue(value: leftDouble < Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for less than comparison") + } + case ">": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return BooleanValue(value: leftInt > rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return BooleanValue(value: leftDouble > rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return BooleanValue(value: Double(leftInt) > rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return BooleanValue(value: leftDouble > Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for greater than comparison") + } + case ">=": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return BooleanValue(value: leftInt >= rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return BooleanValue(value: leftDouble >= rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return BooleanValue(value: Double(leftInt) >= rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return BooleanValue(value: leftDouble >= Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for greater than or equal to comparison") + } + case "<=": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return BooleanValue(value: leftInt <= rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return BooleanValue(value: leftDouble <= rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return BooleanValue(value: Double(leftInt) <= rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return BooleanValue(value: leftDouble <= Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for less than or equal to comparison") + } default: throw JinjaError.runtime("Unknown operation type:\(node.operation.value)") } - } else if left is ArrayValue && right is ArrayValue { + } else if let left = left as? ArrayValue, let right = right as? ArrayValue { switch node.operation.value { - case "+": break + case "+": + return ArrayValue(value: left.value + right.value) default: throw JinjaError.runtime("Unknown operation type:\(node.operation.value)") } - } else if right is ArrayValue { - throw JinjaError.syntaxNotSupported("right is ArrayValue") - } - - if left is StringValue || right is StringValue { - switch node.operation.value { - case "+": - var rightValue = "" - var leftValue = "" - switch right.value { - case let value as String: - rightValue = value - case let value as Int: - rightValue = String(value) - case let value as Bool: - rightValue = String(value) - default: - throw JinjaError.runtime("Unknown right value type:\(type(of: right.value))") - } - - switch left.value { - case let value as String: - leftValue = value - case let value as Int: - leftValue = String(value) - case let value as Bool: - rightValue = String(value) - default: - throw JinjaError.runtime("Unknown left value type:\(type(of: left.value))") - } - - return StringValue(value: leftValue + rightValue) - default: - break + } else if let right = right as? ArrayValue { + let member: Bool + if let left = left as? StringValue { + member = right.value.contains { + if let item = $0 as? StringValue { + return item.value == left.value + } + return false + } + } else if let left = left as? NumericValue { + member = right.value.contains { + if let item = $0 as? NumericValue { + return item.value as! Int == left.value as! Int + } + return false + } + } else if let left = left as? BooleanValue { + member = right.value.contains { + if let item = $0 as? BooleanValue { + return item.value == left.value + } + return false + } + } else { + throw JinjaError.runtime("Unsupported left type for 'in'/'not in' operation with ArrayValue") } - } - - if let left = left as? StringValue, let right = right as? StringValue { switch node.operation.value { case "in": - return BooleanValue(value: right.value.contains(left.value)) + return BooleanValue(value: member) case "not in": - return BooleanValue(value: !right.value.contains(left.value)) + return BooleanValue(value: !member) default: throw JinjaError.runtime("Unknown operation type:\(node.operation.value)") } } - - if left is StringValue, right is ObjectValue { + if let left = left as? StringValue { switch node.operation.value { + case "+": + let rightValue: String + if let rightString = right as? StringValue { + rightValue = rightString.value + } else if let rightNumeric = right as? NumericValue { + rightValue = String(describing: rightNumeric.value) + } else if let rightBoolean = right as? BooleanValue { + rightValue = String(rightBoolean.value) + } else if right is UndefinedValue { + rightValue = "" + } else { + throw JinjaError.runtime("Unsupported right operand type for string concatenation") + } + return StringValue(value: left.value + rightValue) case "in": - if let leftString = (left as? StringValue)?.value, - let rightObject = right as? ObjectValue - { - return BooleanValue(value: rightObject.value.keys.contains(leftString)) + if let right = right as? StringValue { + return BooleanValue(value: right.value.contains(left.value)) + } else if let right = right as? ObjectValue { + return BooleanValue(value: right.value.keys.contains(left.value)) + } else if let right = right as? ArrayValue { + return BooleanValue( + value: right.value.contains { + if let item = $0 as? StringValue { + return item.value == left.value + } + return false + } + ) + } else { + throw JinjaError.runtime("Right operand of 'in' must be a StringValue, ArrayValue, or ObjectValue") } case "not in": - if let leftString = (left as? StringValue)?.value, - let rightObject = right as? ObjectValue - { - return BooleanValue(value: !rightObject.value.keys.contains(leftString)) + if let right = right as? StringValue { + return BooleanValue(value: !right.value.contains(left.value)) + } else if let right = right as? ObjectValue { + return BooleanValue(value: !right.value.keys.contains(left.value)) + } else if let right = right as? ArrayValue { + return BooleanValue( + value: !right.value.contains { + if let item = $0 as? StringValue { + return item.value == left.value + } + return false + } + ) + } else { + throw JinjaError.runtime( + "Right operand of 'not in' must be a StringValue, ArrayValue, or ObjectValue" + ) } + default: + break + } + } else if let right = right as? StringValue { + if node.operation.value == "+" { + if let leftString = left as? StringValue { + return StringValue(value: leftString.value + right.value) + } else if let leftNumeric = left as? NumericValue { + return StringValue(value: String(describing: leftNumeric.value) + right.value) + } else if let leftBoolean = left as? BooleanValue { + return StringValue(value: String(leftBoolean.value) + right.value) + } else { + throw JinjaError.runtime("Unsupported left operand type for string concatenation") + } + } + } + if let left = left as? StringValue, let right = right as? ObjectValue { + switch node.operation.value { + case "in": + return BooleanValue(value: right.value.keys.contains(left.value)) + case "not in": + return BooleanValue(value: !right.value.keys.contains(left.value)) default: throw JinjaError.runtime( "Unsupported operation '\(node.operation.value)' between StringValue and ObjectValue" ) } } - throw JinjaError.syntax( "Unknown operator '\(node.operation.value)' between \(type(of:left)) and \(type(of:right))" ) @@ -442,49 +794,42 @@ struct Interpreter { if !(object is ArrayValue || object is StringValue) { throw JinjaError.runtime("Slice object must be an array or string") } - let start = try self.evaluate(statement: expr.start, environment: environment) let stop = try self.evaluate(statement: expr.stop, environment: environment) let step = try self.evaluate(statement: expr.step, environment: environment) - if !(start is NumericValue || start is UndefinedValue) { throw JinjaError.runtime("Slice start must be numeric or undefined") } - if !(stop is NumericValue || stop is UndefinedValue) { throw JinjaError.runtime("Slice stop must be numeric or undefined") } - if !(step is NumericValue || step is UndefinedValue) { throw JinjaError.runtime("Slice step must be numeric or undefined") } - if let object = object as? ArrayValue { return ArrayValue( value: slice( object.value, - start: start.value as? Int, - stop: stop.value as? Int, - step: step.value as? Int + start: (start as? NumericValue)?.value as? Int, + stop: (stop as? NumericValue)?.value as? Int, + step: (step as? NumericValue)?.value as? Int ) ) } else if let object = object as? StringValue { return StringValue( value: slice( - Array(arrayLiteral: object.value), - start: start.value as? Int, - stop: stop.value as? Int, - step: step.value as? Int - ).joined() + Array(object.value), + start: (start as? NumericValue)?.value as? Int, + stop: (stop as? NumericValue)?.value as? Int, + step: (step as? NumericValue)?.value as? Int + ).map { String($0) }.joined() ) } - throw JinjaError.runtime("Slice object must be an array or string") } func evaluateMemberExpression(expr: MemberExpression, environment: Environment) throws -> any RuntimeValue { let object = try self.evaluate(statement: expr.object, environment: environment) - var property: any RuntimeValue if expr.computed { if let property = expr.property as? SliceExpression { @@ -495,7 +840,6 @@ struct Interpreter { } else { property = StringValue(value: (expr.property as! Identifier).value) } - var value: (any RuntimeValue)? if let object = object as? ObjectValue { if let property = property as? StringValue { @@ -503,34 +847,54 @@ struct Interpreter { } else { throw JinjaError.runtime("Cannot access property with non-string: got \(type(of:property))") } - } else if object is ArrayValue || object is StringValue { + } else if let object = object as? ArrayValue { if let property = property as? NumericValue { - if let object = object as? ArrayValue { - let index = property.value as! Int - if index >= 0 { - value = object.value[index] + if let index = property.value as? Int { + let actualIndex = index < 0 ? object.value.count + index : index + if actualIndex >= 0 && actualIndex < object.value.count { + value = object.value[actualIndex] } else { - value = object.value[object.value.count + index] + value = UndefinedValue() } - } else if let object = object as? StringValue { - let index = object.value.index(object.value.startIndex, offsetBy: property.value as! Int) - value = StringValue(value: String(object.value[index])) + } else { + throw JinjaError.runtime("Array index must be an integer") } } else if let property = property as? StringValue { value = object.builtins[property.value] } else { throw JinjaError.runtime( - "Cannot access property with non-string/non-number: got \(type(of:property))" + "Cannot access property with non-string/non-number: got \(type(of: property))" + ) + } + } else if let object = object as? StringValue { + if let property = property as? NumericValue { + if let index = property.value as? Int { + if index >= 0 && index < object.value.count { + let strIndex = object.value.index(object.value.startIndex, offsetBy: index) + value = StringValue(value: String(object.value[strIndex])) + } else if index < 0 && index >= -object.value.count { + let strIndex = object.value.index(object.value.startIndex, offsetBy: object.value.count + index) + value = StringValue(value: String(object.value[strIndex])) + } else { + value = UndefinedValue() + } + } else { + throw JinjaError.runtime("String index must be an integer") + } + } else if let property = property as? StringValue { + value = object.builtins[property.value] + } else { + throw JinjaError.runtime( + "Cannot access property with non-string/non-number: got \(type(of: property))" ) } } else { if let property = property as? StringValue { - value = object.builtins[property.value]! + value = object.builtins[property.value] } else { throw JinjaError.runtime("Cannot access property with non-string: got \(type(of:property))") } } - if let value { return value } else { @@ -540,7 +904,6 @@ struct Interpreter { func evaluateUnaryExpression(node: UnaryExpression, environment: Environment) throws -> any RuntimeValue { let argument = try self.evaluate(statement: node.argument, environment: environment) - switch node.operation.value { case "not": return BooleanValue(value: !argument.bool()) @@ -552,7 +915,6 @@ struct Interpreter { func evaluateCallExpression(expr: CallExpression, environment: Environment) throws -> any RuntimeValue { var args: [any RuntimeValue] = [] var kwargs: [String: any RuntimeValue] = [:] - for argument in expr.args { if let argument = argument as? KeywordArgumentExpression { kwargs[argument.key.value] = try self.evaluate(statement: argument.value, environment: environment) @@ -560,13 +922,10 @@ struct Interpreter { try args.append(self.evaluate(statement: argument, environment: environment)) } } - - if kwargs.count > 0 { + if !kwargs.isEmpty { args.append(ObjectValue(value: kwargs)) } - let fn = try self.evaluate(statement: expr.callee, environment: environment) - if let fn = fn as? FunctionValue { return try fn.value(args, environment) } else { @@ -574,89 +933,108 @@ struct Interpreter { } } - func evaluateFilterExpression(node: FilterExpression, environment: Environment) throws -> any RuntimeValue { - let operand = try evaluate(statement: node.operand, environment: environment) - - if let identifier = node.filter as? Identifier { - if let arrayValue = operand as? ArrayValue { - switch identifier.value { - case "list": - return arrayValue - case "first": - return arrayValue.value.first ?? UndefinedValue() - case "last": - return arrayValue.value.last ?? UndefinedValue() - case "length": - return NumericValue(value: arrayValue.value.count) - case "reverse": - return ArrayValue(value: arrayValue.value.reversed()) - case "sort": - throw JinjaError.todo("TODO: ArrayValue filter sort") - default: - throw JinjaError.runtime("Unknown ArrayValue filter: \(identifier.value)") - } - } else if let stringValue = operand as? StringValue { - switch identifier.value { - case "length": - return NumericValue(value: stringValue.value.count) - case "upper": - return StringValue(value: stringValue.value.uppercased()) - case "lower": - return StringValue(value: stringValue.value.lowercased()) - case "title": - return StringValue(value: stringValue.value.capitalized) - case "capitalize": - return StringValue(value: stringValue.value.capitalized) - case "trim": - return StringValue(value: stringValue.value.trimmingCharacters(in: .whitespacesAndNewlines)) - default: - throw JinjaError.runtime("Unknown StringValue filter: \(identifier.value)") - } - } else if let numericValue = operand as? NumericValue { - switch identifier.value { - case "abs": - return NumericValue(value: abs(numericValue.value as! Int32)) - default: - throw JinjaError.runtime("Unknown NumericValue filter: \(identifier.value)") - } - } else if let objectValue = operand as? ObjectValue { - switch identifier.value { - case "items": - var items: [ArrayValue] = [] - for (k, v) in objectValue.value { - items.append( - ArrayValue(value: [ - StringValue(value: k), - v, - ]) - ) - } - return items as! (any RuntimeValue) - case "length": - return NumericValue(value: objectValue.value.count) - default: - throw JinjaError.runtime("Unknown ObjectValue filter: \(identifier.value)") - } + func evaluateFilterExpression(node: FilterExpression, environment: Environment, whitespaceControl: Bool) throws + -> any RuntimeValue + { + let operand = try self.evaluate(statement: node.operand, environment: environment) + let filterName = node.filter.value + guard let filter = environment.filters[filterName] else { + throw JinjaError.runtime("No filter named '\(filterName)'") + } + // Evaluate positional arguments + let evaluatedPositionalArgs = try node.args.map { arg in + try self.evaluate(statement: arg, environment: environment) + } + // Create args array starting with operand + var args: [any RuntimeValue] = [operand] + args.append(contentsOf: evaluatedPositionalArgs) + // If we have keyword arguments, add them as a final ObjectValue argument + if !node.kwargs.isEmpty { + var kwargs: [String: any RuntimeValue] = [:] + for kwarg in node.kwargs { + kwargs[kwarg.key.value] = try self.evaluate(statement: kwarg.value, environment: environment) } - - throw JinjaError.runtime("Cannot apply filter \(operand.value) to type: \(type(of:operand))") + args.append(ObjectValue(value: kwargs)) } - - throw JinjaError.runtime("Unknown filter: \(node.filter)") + return try filter(args, environment) } func evaluateTestExpression(node: TestExpression, environment: Environment) throws -> any RuntimeValue { let operand = try self.evaluate(statement: node.operand, environment: environment) - - if let testFunction = environment.tests[node.test.value] { - let result = try testFunction(operand) - return BooleanValue(value: node.negate ? !result : result) - } else { + guard let testFunction = environment.tests[node.test.value] else { throw JinjaError.runtime("Unknown test: \(node.test.value)") } + let result = try testFunction(operand) + return BooleanValue(value: node.negate ? !result : result) + } + + func evaluateMacro(node: Macro, environment: Environment) throws -> NullValue { + try environment.setVariable( + name: node.name.value, + value: FunctionValue(value: { args, scope in + let macroScope = Environment(parent: scope) + var args = args + var kwargs: [String: any RuntimeValue] = [:] + if let lastArg = args.last, let keywordArgsValue = lastArg as? KeywordArgumentsValue { + kwargs = keywordArgsValue.value + args.removeLast() + } + for i in 0 ..< node.args.count { + let nodeArg = node.args[i] + let passedArg = args.count > i ? args[i] : nil + + if let identifier = nodeArg as? Identifier { + if passedArg == nil { + if let defaultValue = kwargs[identifier.value] { + try macroScope.setVariable(name: identifier.value, value: defaultValue) + } else { + throw JinjaError.runtime("Missing argument: \(identifier.value)") + } + } else { + try macroScope.setVariable(name: identifier.value, value: passedArg!) + } + } else if let kwarg = nodeArg as? KeywordArgumentExpression { + let value = + try kwargs[kwarg.key.value] + ?? (passedArg ?? (try self.evaluate(statement: kwarg.value, environment: macroScope))) + + try macroScope.setVariable(name: kwarg.key.value, value: value) + } else { + throw JinjaError.runtime("Unknown argument type: \(type(of: nodeArg))") + } + } + return try self.evaluateBlock(statements: node.body, environment: macroScope) + }) + ) + return NullValue() } - func evaluate(statement: Statement?, environment: Environment) throws -> any RuntimeValue { + func evaluateArguments( + args: [Expression], + environment: Environment + ) throws -> ([any RuntimeValue], [String: any RuntimeValue]) { + var positionalArguments: [any RuntimeValue] = [] + var keywordArguments: [String: any RuntimeValue] = [:] + for argument in args { + if let keywordArgument = argument as? KeywordArgumentExpression { + keywordArguments[keywordArgument.key.value] = try self.evaluate( + statement: keywordArgument.value, + environment: environment + ) + } else { + if !keywordArguments.isEmpty { + throw JinjaError.runtime("Positional arguments must come before keyword arguments") + } + positionalArguments.append(try self.evaluate(statement: argument, environment: environment)) + } + } + + return (positionalArguments, keywordArguments) + } + + func evaluate(statement: Statement?, environment: Environment, whitespaceControl: Bool = false) throws + -> any RuntimeValue + { if let statement { switch statement { case let statement as Program: @@ -678,15 +1056,41 @@ struct Interpreter { case let statement as UnaryExpression: return try self.evaluateUnaryExpression(node: statement, environment: environment) case let statement as NumericLiteral: - return NumericValue(value: statement.value) + if let intValue = statement.value as? Int { + return NumericValue(value: intValue) + } else if let doubleValue = statement.value as? Double { + return NumericValue(value: doubleValue) + } else { + throw JinjaError.runtime("Invalid numeric literal value") + } case let statement as CallExpression: return try self.evaluateCallExpression(expr: statement, environment: environment) case let statement as BoolLiteral: return BooleanValue(value: statement.value) case let statement as FilterExpression: - return try self.evaluateFilterExpression(node: statement, environment: environment) + return try self.evaluateFilterExpression( + node: statement, + environment: environment, + whitespaceControl: whitespaceControl + ) case let statement as TestExpression: return try self.evaluateTestExpression(node: statement, environment: environment) + case let statement as ArrayLiteral: + return ArrayValue( + value: try statement.value.map { try self.evaluate(statement: $0, environment: environment) } + ) + case let statement as TupleLiteral: + return TupleValue( + value: try statement.value.map { try self.evaluate(statement: $0, environment: environment) } + ) + case let statement as ObjectLiteral: + var mapping: [String: any RuntimeValue] = [:] + for (key, value) in statement.value { + mapping[key] = try self.evaluate(statement: value, environment: environment) + } + return ObjectValue(value: mapping) + case let statement as Macro: + return try self.evaluateMacro(node: statement, environment: environment) case is NullLiteral: return NullValue() default: diff --git a/Sources/Template.swift b/Sources/Template.swift index 8f78efe..04a93c5 100644 --- a/Sources/Template.swift +++ b/Sources/Template.swift @@ -15,7 +15,7 @@ public struct Template { self.parsed = try parse(tokens: tokens) } - public func render(_ items: [String: Any]) throws -> String { + public func render(_ items: [String: Any?]) throws -> String { let env = Environment() try env.set(name: "false", value: false) @@ -30,7 +30,9 @@ public struct Template { try env.set(name: "range", value: range) for (key, value) in items { - try env.set(name: key, value: value) + if let value { + try env.set(name: key, value: value) + } } let interpreter = Interpreter(env: env) diff --git a/Sources/Utilities.swift b/Sources/Utilities.swift index c01870b..7017acb 100644 --- a/Sources/Utilities.swift +++ b/Sources/Utilities.swift @@ -21,7 +21,6 @@ func slice(_ array: [T], start: Int? = nil, stop: Int? = nil, step: Int? = 1) let stopValue = stop ?? arrayCount let step = step ?? 1 var slicedArray = [T]() - if step > 0 { let startIndex = startValue < 0 ? max(arrayCount + startValue, 0) : min(startValue, arrayCount) let stopIndex = stopValue < 0 ? max(arrayCount + stopValue, 0) : min(stopValue, arrayCount) @@ -35,6 +34,197 @@ func slice(_ array: [T], start: Int? = nil, stop: Int? = nil, step: Int? = 1) slicedArray.append(array[i]) } } - return slicedArray } + +func toJSON(_ input: any RuntimeValue, indent: Int? = nil, depth: Int = 0, whitespaceControl: Bool = false) throws + -> String +{ + // If whitespaceControl is true, output compact JSON + if whitespaceControl { + switch input { + case is NullValue, is UndefinedValue: + return "null" + case let value as NumericValue: + return String(describing: value.value) + case let value as StringValue: + let escapedValue = value.value + .replacingOccurrences(of: "\\", with: "\\\\") + .replacingOccurrences(of: "\"", with: "\\\"") + .replacingOccurrences(of: "\n", with: "\\n") + .replacingOccurrences(of: "\r", with: "\\r") + .replacingOccurrences(of: "\t", with: "\\t") + return "\"\(escapedValue)\"" + case let value as BooleanValue: + return value.value ? "true" : "false" + case let arr as ArrayValue: + let elements = try arr.value.map { + try toJSON($0, indent: nil, depth: 0, whitespaceControl: true) + } + return "[\(elements.joined(separator: ", "))]" + case let obj as ObjectValue: + let pairs = try obj.orderedKeys.map { key in + guard let value = obj.value[key] else { + throw JinjaError.runtime("Missing value for key: \(key)") + } + let jsonValue = try toJSON(value, indent: nil, depth: 0, whitespaceControl: true) + return "\"\(key)\": \(jsonValue)" + } + return "{\(pairs.joined(separator: ", "))}" + default: + throw JinjaError.runtime("Cannot convert to JSON: \(type(of: input))") + } + } + let currentDepth = depth + let indentValue = indent != nil ? String(repeating: " ", count: indent!) : "" + let basePadding = indent != nil ? "\n" + String(repeating: indentValue, count: currentDepth) : "" + let childrenPadding = indent != nil ? basePadding + indentValue : "" + switch input { + case is NullValue, is UndefinedValue: + return "null" + case let value as NumericValue: + return String(describing: value.value) + case let value as StringValue: + // Properly escape special characters for JSON strings + let escapedValue = value.value + .replacingOccurrences(of: "\\", with: "\\\\") + .replacingOccurrences(of: "\"", with: "\\\"") + .replacingOccurrences(of: "\n", with: "\\n") + .replacingOccurrences(of: "\r", with: "\\r") + .replacingOccurrences(of: "\t", with: "\\t") + return "\"\(escapedValue)\"" + case let value as BooleanValue: + return value.value ? "true" : "false" + case let arr as ArrayValue: + let core = try arr.value.map { + try toJSON($0, indent: indent, depth: currentDepth + 1, whitespaceControl: whitespaceControl) + } + if indent != nil && !whitespaceControl { + return "[\(childrenPadding)\(core.joined(separator: ",\(childrenPadding)"))\(basePadding)]" + } else { + return "[\(core.joined(separator: ", "))]" + } + case let obj as ObjectValue: + // Use orderedKeys to maintain insertion order + let pairs = try obj.orderedKeys.map { key in + guard let value = obj.value[key] else { + throw JinjaError.runtime("Missing value for key: \(key)") + } + let jsonValue = try toJSON( + value, + indent: indent, + depth: currentDepth + 1, + whitespaceControl: whitespaceControl + ) + return "\"\(key)\": \(jsonValue)" + } + if indent != nil && !whitespaceControl { + return "{\(childrenPadding)\(pairs.joined(separator: ",\(childrenPadding)"))\(basePadding)}" + } else { + return "{\(pairs.joined(separator: ", "))}" + } + default: + throw JinjaError.runtime("Cannot convert to JSON: \(type(of: input))") + } +} + +// Helper function to convert values to JSON strings +func jsonString(_ value: Any) throws -> String { + let data = try JSONSerialization.data(withJSONObject: value) + guard let string = String(data: data, encoding: .utf8) else { + throw JinjaError.runtime("Failed to convert value to JSON string") + } + return string +} + +extension String { + func titleCase() -> String { + return self.components(separatedBy: .whitespacesAndNewlines) + .map { word in + guard let firstChar = word.first else { return "" } + return String(firstChar).uppercased() + word.dropFirst() + } + .joined(separator: " ") + } + + func indent(_ width: Int, first: Bool = false, blank: Bool = false) -> String { + let indentString = String(repeating: " ", count: width) + return self.components(separatedBy: .newlines) + .enumerated() + .map { (index, line) in + if line.isEmpty && !blank { + return line + } + if index == 0 && !first { + return line + } + return indentString + line + } + .joined(separator: "\n") + } +} + +func stringify(_ value: any RuntimeValue, indent: Int = 4, whitespaceControl: Bool = false) throws -> String { + if let stringValue = value as? StringValue { + return "\"\(stringValue.value)\"" + } else if let numericValue = value as? NumericValue { + return String(describing: numericValue.value) + } else if let booleanValue = value as? BooleanValue { + return booleanValue.value ? "true" : "false" + } else if let objectValue = value as? ObjectValue { + return try toJSON(objectValue, indent: indent, whitespaceControl: whitespaceControl) + } else if let arrayValue = value as? ArrayValue { + return try toJSON(arrayValue, indent: indent, whitespaceControl: whitespaceControl) + } else if value is NullValue { + return "null" + } else if value is UndefinedValue { + return "undefined" + } else { + return "" + } +} + +func areEqual(_ left: any RuntimeValue, _ right: any RuntimeValue) throws -> Bool { + if let leftObj = left as? ObjectValue, let rightObj = right as? ObjectValue { + // Compare ObjectValues by their contents + guard leftObj.storage.keys == rightObj.storage.keys else { + return false + } + + for key in leftObj.storage.keys { + guard let leftValue = leftObj.storage[key], + let rightValue = rightObj.storage[key], + try areEqual(leftValue, rightValue) + else { + return false + } + } + return true + } else if let leftStr = left as? StringValue, let rightStr = right as? StringValue { + return leftStr.value == rightStr.value + } else if let leftNum = left as? NumericValue, let rightNum = right as? NumericValue { + if let leftInt = leftNum.value as? Int, let rightInt = rightNum.value as? Int { + return leftInt == rightInt + } else if let leftDouble = leftNum.value as? Double, let rightDouble = rightNum.value as? Double { + return leftDouble == rightDouble + } + } else if let leftArr = left as? ArrayValue, let rightArr = right as? ArrayValue { + guard leftArr.value.count == rightArr.value.count else { + return false + } + for (leftItem, rightItem) in zip(leftArr.value, rightArr.value) { + guard try areEqual(leftItem, rightItem) else { + return false + } + } + return true + } else if left is NullValue && right is NullValue { + return true + } else if left is UndefinedValue && right is UndefinedValue { + return true + } else if let leftBool = left as? BooleanValue, let rightBool = right as? BooleanValue { + return leftBool.value == rightBool.value + } + // If types don't match, return false + return false +} diff --git a/Tests/ChatTemplateTests.swift b/Tests/ChatTemplateTests.swift deleted file mode 100644 index 4b9ab6b..0000000 --- a/Tests/ChatTemplateTests.swift +++ /dev/null @@ -1,238 +0,0 @@ -// -// ChatTemplateTests.swift -// -// -// Created by John Mai on 2024/3/24. -// - -import XCTest - -@testable import Jinja - -let messages: [[String: String]] = [ - [ - "role": "user", - "content": "Hello, how are you?", - ], - [ - "role": "assistant", - "content": "I'm doing great. How can I help you today?", - ], - [ - "role": "user", - "content": "I'd like to show off how chat templating works!", - ], -] - -let messagesWithSystem: [[String: String]] = - [ - [ - "role": "system", - "content": "You are a friendly chatbot who always responds in the style of a pirate", - ] - ] + messages - -final class ChatTemplateTests: XCTestCase { - struct Test { - let chatTemplate: String - let data: [String: Any] - let target: String - } - - let defaultTemplates: [Test] = [ - Test( - chatTemplate: - "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", - data: [ - "messages": messages, - "add_generation_prompt": false, - ], - target: - "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n" - ), - // facebook/blenderbot-400M-distill - Test( - chatTemplate: - "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}", - data: [ - "messages": messages, - "eos_token": "", - ], - target: - " Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!" - ), - // facebook/blenderbot_small-90M - Test( - chatTemplate: - "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}", - data: [ - "messages": messages, - "eos_token": "", - ], - target: - " Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!" - ), - // bigscience/bloom - Test( - chatTemplate: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", - data: [ - "messages": messages, - "eos_token": "", - ], - target: - "Hello, how are you?I'm doing great. How can I help you today?I'd like to show off how chat templating works!" - ), - // EleutherAI/gpt-neox-20b - Test( - chatTemplate: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", - data: [ - "messages": messages, - "eos_token": "<|endoftext|>", - ], - target: - "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>" - ), - // gpt2 - Test( - chatTemplate: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", - data: [ - "messages": messages, - "eos_token": "<|endoftext|>", - ], - target: - "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>" - ), - // hf-internal-testing/llama-tokenizer - Test( - chatTemplate: - "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}", - data: [ - "messages": messagesWithSystem, - "bos_token": "", - "eos_token": "", - "USE_DEFAULT_PROMPT": true, - ], - target: - "[INST] <>\nYou are a friendly chatbot who always responds in the style of a pirate\n<>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]" - ), - // hf-internal-testing/llama-tokenizer - Test( - chatTemplate: - "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}", - data: [ - "messages": messages, - "bos_token": "", - "eos_token": "", - "USE_DEFAULT_PROMPT": true, - ], - target: - "[INST] <>\nDEFAULT_SYSTEM_MESSAGE\n<>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]" - ), - // hf-internal-testing/llama-tokenizer - Test( - chatTemplate: - "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}", - data: [ - "messages": [ - [ - "role": "user", - "content": "<>\nYou are a helpful assistant\n<> Hello, how are you?", - ], - [ - "role": "assistant", - "content": "I'm doing great. How can I help you today?", - ], - [ - "role": "user", - "content": "I'd like to show off how chat templating works!", - ], - ], - "bos_token": "", - "eos_token": "", - "USE_DEFAULT_PROMPT": true, - ], - target: - "[INST] <>\nYou are a helpful assistant\n<> Hello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]" - ), - // openai/whisper-large-v3 - Test( - chatTemplate: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", - data: [ - "messages": messages, - "eos_token": "<|endoftext|>", - ], - target: - "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>" - ), - // Qwen/Qwen1.5-1.8B-Chat - Test( - chatTemplate: - "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}", - data: [ - "messages": messages, - "add_generation_prompt": true, - ], - target: - "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI\'m doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI\'d like to show off how chat templating works!<|im_end|>\n<|im_start|>assistant\n" - ), - // Qwen/Qwen1.5-1.8B-Chat - Test( - chatTemplate: - "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}", - data: [ - "messages": messagesWithSystem, - "add_generation_prompt": true, - ], - target: - "<|im_start|>system\nYou are a friendly chatbot who always responds in the style of a pirate<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI\'m doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI\'d like to show off how chat templating works!<|im_end|>\n<|im_start|>assistant\n" - ), - // Qwen/Qwen1.5-1.8B-Chat - Test( - chatTemplate: - "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}", - data: [ - "messages": messagesWithSystem - ], - target: - "<|im_start|>system\nYou are a friendly chatbot who always responds in the style of a pirate<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI\'m doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI\'d like to show off how chat templating works!" - ), - // THUDM/chatglm3-6b - Test( - chatTemplate: - "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}", - data: [ - "messages": messagesWithSystem - ], - target: - "[gMASK]sop<|system|>\n You are a friendly chatbot who always responds in the style of a pirate<|user|>\n Hello, how are you?<|assistant|>\n I\'m doing great. How can I help you today?<|user|>\n I\'d like to show off how chat templating works!" - ), - // google/gemma-2b-it - Test( - chatTemplate: - "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}", - data: [ - "messages": messages - ], - target: - "user\nHello, how are you?\nmodel\nI\'m doing great. How can I help you today?\nuser\nI\'d like to show off how chat templating works!\n" - ), - // Qwen/Qwen2.5-0.5B-Instruct - Test( - chatTemplate: - "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n", - data: [ - "messages": messages - ], - target: - "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI\'m doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI\'d like to show off how chat templating works!<|im_end|>\n" - ), - ] - - func testDefaultTemplates() throws { - for test in defaultTemplates { - let template = try Template(test.chatTemplate) - let result = try template.render(test.data) - XCTAssertEqual(result.debugDescription, test.target.debugDescription) - } - } -} diff --git a/Tests/CoreTagTests.swift b/Tests/CoreTagTests.swift new file mode 100644 index 0000000..37147b1 --- /dev/null +++ b/Tests/CoreTagTests.swift @@ -0,0 +1,761 @@ +// +// CoreTagTests.swift +// Jinja +// +// Created by Anthony DePasquale on 07.01.2025. +// + +// Adapted from https://github.com/pallets/jinja/blob/main/tests/test_core_tags.py + +import XCTest +@testable import Jinja + +final class IfConditionTests: XCTestCase { + // MARK: - If Condition Tests + + func testSimpleIf() throws { + let template = try Template("{% if true %}...{% endif %}") + let result = try template.render([:]) + XCTAssertEqual(result, "...") + } + + func testIfElif() throws { + let template = try Template( + """ + {% if false %}XXX{% elif true %}...{% else %}XXX{% endif %} + """ + ) + let result = try template.render([:]) + XCTAssertEqual(result, "...") + } + + func testIfElse() throws { + let template = try Template("{% if false %}XXX{% else %}...{% endif %}") + let result = try template.render([:]) + XCTAssertEqual(result, "...") + } + + func testEmptyIf() throws { + let template = try Template("[{% if true %}{% else %}{% endif %}]") + let result = try template.render([:]) + XCTAssertEqual(result, "[]") + } + + // TODO: Make this test pass + // func testCompleteIf() throws { + // let template = try Template( + // """ + // {% if a %}A{% elif b %}B{% elif c == d %}C{% else %}D{% endif %} + // """ + // ) + // let result = try template.render([ + // "a": 0, + // "b": false, + // "c": 42, + // "d": 42.0, + // ]) + // XCTAssertEqual(result, "C") + // } + + // MARK: - Set Tests + + func testNormalSet() throws { + let template = try Template("{% set foo = 1 %}{{ foo }}") + let result = try template.render([:]) + XCTAssertEqual(result, "1") + } + + // TODO: Make this test pass + // func testBlockSet() throws { + // let template = try Template("{% set foo %}42{% endset %}{{ foo }}") + // let result = try template.render([:]) + // XCTAssertEqual(result, "42") + // } + + func testNamespace() throws { + let template = try Template( + """ + {% set ns = namespace() %}{% set ns.bar = '42' %}{{ ns.bar }} + """ + ) + let result = try template.render([:]) + XCTAssertEqual(result, "42") + } + + // TODO: Make this test pass + // func testNamespaceLoop() throws { + // let template = try Template( + // """ + // {% set ns = namespace(found=false) %}\ + // {% for x in range(4) %}\ + // {% if x == v %}\ + // {% set ns.found = true %}\ + // {% endif %}\ + // {% endfor %}\ + // {{ ns.found }} + // """ + // ) + // + // let result1 = try template.render(["v": 3]) + // XCTAssertEqual(result1, "true") + // + // let result2 = try template.render(["v": 4]) + // XCTAssertEqual(result2, "false") + // } +} + +final class ForLoopTests: XCTestCase { + // MARK: - For Loop Tests + + func testSimpleForLoop() throws { + let template = try Template("{% for item in seq %}{{ item }}{% endfor %}") + let result = try template.render(["seq": Array(0 ... 9)]) + XCTAssertEqual(result, "0123456789") + } + + // TODO: Make this test pass + // func testForLoopWithElse() throws { + // let template = try Template("{% for item in seq %}XXX{% else %}...{% endfor %}") + // let result = try template.render([:]) + // XCTAssertEqual(result, "...") + // } + + func testForLoopElseScopingItem() throws { + let template = try Template("{% for item in [] %}{% else %}{{ item }}{% endfor %}") + let result = try template.render(["item": 42]) + XCTAssertEqual(result, "42") + } + + // TODO: Make this test pass + // func testEmptyBlocks() throws { + // let template = try Template("<{% for item in seq %}{% else %}{% endfor %}>") + // let result = try template.render([:]) + // XCTAssertEqual(result, "<>") + // } + + func testContextVars() throws { + let template = try Template( + """ + {% for item in seq -%} + {{ loop.index }}|{{ loop.index0 }}|{{ loop.revindex }}|{{ + loop.revindex0 }}|{{ loop.first }}|{{ loop.last }}|{{ + loop.length }}###{% endfor %} + """ + ) + + let result = try template.render(["seq": [42, 24]]) + let parts = result.split(separator: "###") + XCTAssertEqual(parts.count, 2) + + let one = String(parts[0]).split(separator: "|") + let two = String(parts[1]).split(separator: "|") + + // First iteration checks + XCTAssertEqual(one[0], "1") // index + XCTAssertEqual(one[1], "0") // index0 + XCTAssertEqual(one[2], "2") // revindex + XCTAssertEqual(one[3], "1") // revindex0 + XCTAssertEqual(one[4], "true") // first + XCTAssertEqual(one[5], "false") // last + XCTAssertEqual(one[6], "2") // length + + // Second iteration checks + XCTAssertEqual(two[0], "2") // index + XCTAssertEqual(two[1], "1") // index0 + XCTAssertEqual(two[2], "1") // revindex + XCTAssertEqual(two[3], "0") // revindex0 + XCTAssertEqual(two[4], "false") // first + XCTAssertEqual(two[5], "true") // last + XCTAssertEqual(two[6], "2") // length + } + + // TODO: Make this test pass + // func testCycling() throws { + // let template = try Template( + // """ + // {% for item in seq %}{{ loop.cycle('<1>', '<2>') }}{% endfor %}\ + // {% for item in seq %}{{ loop.cycle(*through) }}{% endfor %} + // """ + // ) + // let result = try template.render([ + // "seq": Array(0 ... 3), + // "through": ["<1>", "<2>"], + // ]) + // XCTAssertEqual(result, "<1><2><1><2><1><2><1><2>") + // } + + func testLookaround() throws { + let template = try Template( + """ + {% for item in seq -%} + {{ loop.previtem|default('x') }}-{{ item }}-{{ loop.nextitem|default('x') }}| + {%- endfor %} + """ + ) + let result = try template.render(["seq": Array(0 ... 3)]) + XCTAssertEqual(result, "x-0-1|0-1-2|1-2-3|2-3-x|") + } + + func testScope() throws { + let template = try Template("{% for item in seq %}{% endfor %}{{ item }}") + let result = try template.render(["seq": Array(0 ... 9)]) + XCTAssertEqual(result, "") + } + + func testVarlen() throws { + let template = try Template("{% for item in iter %}{{ item }}{% endfor %}") + let result = try template.render(["iter": Array(0 ... 4)]) + XCTAssertEqual(result, "01234") + } + + func testNoniter() throws { + let template = try Template("{% for item in none %}...{% endfor %}") + XCTAssertThrowsError(try template.render(["none": nil])) + } + + // TODO: Make this test pass + // func testRecursive() throws { + // let template = try Template( + // """ + // {% for item in seq recursive -%} + // [{{ item.a }}{% if item.b %}<{{ loop(item.b) }}>{% endif %}] + // {%- endfor %} + // """ + // ) + // + // let data: [String: Any] = [ + // "seq": [ + // ["a": 1, "b": [["a": 1], ["a": 2]]], + // ["a": 2, "b": [["a": 1], ["a": 2]]], + // ["a": 3, "b": [["a": "a"]]], + // ] + // ] + // + // let result = try template.render(data) + // XCTAssertEqual(result, "[1<[1][2]>][2<[1][2]>][3<[a]>]") + // } + + func testLooploop() throws { + let template = try Template( + """ + {% for row in table %} + {%- set rowloop = loop -%} + {% for cell in row -%} + [{{ rowloop.index }}|{{ loop.index }}] + {%- endfor %} + {%- endfor %} + """ + ) + + let result = try template.render(["table": ["ab", "cd"]]) + XCTAssertEqual(result, "[1|1][1|2][2|1][2|2]") + } + + func testLoopFilter() throws { + let template = try Template( + "{% for item in range(10) if item is even %}[{{ item }}]{% endfor %}" + ) + let result = try template.render([:]) + XCTAssertEqual(result, "[0][2][4][6][8]") + + let template2 = try Template( + """ + {%- for item in range(10) if item is even %}[{{ loop.index }}:{{ item }}]{% endfor %} + """ + ) + let result2 = try template2.render([:]) + XCTAssertEqual(result2, "[1:0][2:2][3:4][4:6][5:8]") + } + + func testUnpacking() throws { + let template = try Template( + "{% for a, b, c in [[1, 2, 3]] %}{{ a }}|{{ b }}|{{ c }}{% endfor %}" + ) + let result = try template.render([:]) + XCTAssertEqual(result, "1|2|3") + } + + // TODO: Make this test pass + // func testRecursiveLookaround() throws { + // let template = try Template( + // """ + // {% for item in seq recursive -%} + // [{{ loop.previtem.a if loop.previtem is defined else 'x' }}.\ + // {{ item.a }}.\ + // {{ loop.nextitem.a if loop.nextitem is defined else 'x' }}\ + // {% if item.b %}<{{ loop(item.b) }}>{% endif %}] + // {%- endfor %} + // """ + // ) + // + // let data: [String: Any] = [ + // "seq": [ + // ["a": 1, "b": [["a": 1], ["a": 2]]], + // ["a": 2, "b": [["a": 1], ["a": 2]]], + // ["a": 3, "b": [["a": "a"]]], + // ] + // ] + // + // let result = try template.render(data) + // XCTAssertEqual(result, "[x.1.2<[x.1.2][1.2.x]>][1.2.3<[x.1.2][1.2.x]>][2.3.x<[x.a.x]>]") + // } + + // TODO: Make this test pass + // func testRecursiveDepth0() throws { + // let template = try Template( + // """ + // {% for item in seq recursive -%} + // [{{ loop.depth0 }}:{{ item.a }}\ + // {% if item.b %}<{{ loop(item.b) }}>{% endif %}] + // {%- endfor %} + // """ + // ) + // + // let data: [String: Any] = [ + // "seq": [ + // ["a": 1, "b": [["a": 1], ["a": 2]]], + // ["a": 2, "b": [["a": 1], ["a": 2]]], + // ["a": 3, "b": [["a": "a"]]], + // ] + // ] + // + // let result = try template.render(data) + // XCTAssertEqual(result, "[0:1<[1:1][1:2]>][0:2<[1:1][1:2]>][0:3<[1:a]>]") + // } + + // TODO: Make this test pass + // func testRecursiveDepth() throws { + // let template = try Template( + // """ + // {% for item in seq recursive -%} + // [{{ loop.depth }}:{{ item.a }}\ + // {% if item.b %}<{{ loop(item.b) }}>{% endif %}] + // {%- endfor %} + // """ + // ) + // + // let data: [String: Any] = [ + // "seq": [ + // ["a": 1, "b": [["a": 1], ["a": 2]]], + // ["a": 2, "b": [["a": 1], ["a": 2]]], + // ["a": 3, "b": [["a": "a"]]], + // ] + // ] + // + // let result = try template.render(data) + // XCTAssertEqual(result, "[1:1<[2:1][2:2]>][1:2<[2:1][2:2]>][1:3<[2:a]>]") + // } + + // TODO: Make this test pass + // func testReversedBug() throws { + // let template = try Template( + // """ + // {% for i in items %}{{ i }}\ + // {% if not loop.last %},{% endif %}\ + // {% endfor %} + // """ + // ) + // let result = try template.render(["items": [3, 2, 1].reversed()]) + // XCTAssertEqual(result.trimmingCharacters(in: .whitespaces), "1,2,3") + // } + + // TODO: Make this test pass + // func testLoopErrors() throws { + // // Test accessing loop variable before loop starts + // let template1 = try Template( + // """ + // {% for item in [1] if loop.index == 0 %}...{% endfor %} + // """ + // ) + // XCTAssertThrowsError(try template1.render([:])) + // + // // Test accessing loop in else block + // let template2 = try Template( + // """ + // {% for item in [] %}...{% else %}{{ loop }}{% endfor %} + // """ + // ) + // let result = try template2.render([:]) + // XCTAssertEqual(result, "") + // } + + func testScopedSpecialVar() throws { + let template = try Template( + """ + {% for s in seq %}[{{ loop.first }}\ + {% for c in s %}|{{ loop.first }}{% endfor %}]\ + {% endfor %} + """ + ) + let result = try template.render(["seq": ["ab", "cd"]]) + XCTAssertEqual(result, "[true|true|false][false|true|false]") + } + + func testScopedLoopVar() throws { + let template1 = try Template( + """ + {% for x in seq %}{{ loop.first }}\ + {% for y in seq %}{% endfor %}\ + {% endfor %} + """ + ) + let result1 = try template1.render(["seq": "ab"]) + XCTAssertEqual(result1, "truefalse") + + let template2 = try Template( + """ + {% for x in seq %}\ + {% for y in seq %}{{ loop.first }}\ + {% endfor %}\ + {% endfor %} + """ + ) + let result2 = try template2.render(["seq": "ab"]) + XCTAssertEqual(result2, "truefalsetruefalse") + } + + // TODO: Make this test pass + // func testRecursiveEmptyLoopIter() throws { + // let template = try Template( + // """ + // {%- for item in foo recursive -%}\ + // {%- endfor -%} + // """ + // ) + // let result = try template.render(["foo": []]) + // XCTAssertEqual(result, "") + // } + + // TODO: Make this test pass + // func testCallInLoop() throws { + // let template = try Template( + // """ + // {%- macro do_something() -%} + // [{{ caller() }}] + // {%- endmacro %} + // + // {%- for i in [1, 2, 3] %} + // {%- call do_something() -%} + // {{ i }} + // {%- endcall %} + // {%- endfor -%} + // """ + // ) + // let result = try template.render([:]) + // XCTAssertEqual(result, "[1][2][3]") + // } +} + +final class MacroTests: XCTestCase { + func testSimpleMacro() throws { + let template = try Template( + """ + {% macro say_hello(name) %}Hello {{ name }}!{% endmacro %} + {{ say_hello('Peter') }} + """ + ) + let result = try template.render([:]) + XCTAssertEqual(result.trimmingCharacters(in: .whitespaces), "Hello Peter!") + } + + func testMacroScoping() throws { + let template = try Template( + """ + {% macro level1(data1) %} + {% macro level2(data2) %}{{ data1 }}|{{ data2 }}{% endmacro %} + {{ level2('bar') }}{% endmacro %} + {{ level1('foo') }} + """ + ) + let result = try template.render([:]) + XCTAssertEqual(result.trimmingCharacters(in: .whitespaces), "foo|bar") + } + + // TODO: Make this test pass + // func testMacroArguments() throws { + // let template = try Template( + // """ + // {% macro m(a, b, c='c', d='d') %}{{ a }}|{{ b }}|{{ c }}|{{ d }}{% endmacro %} + // {{ m() }}|{{ m('a') }}|{{ m('a', 'b') }}|{{ m(1, 2, 3) }} + // """ + // ) + // let result = try template.render([:]) + // XCTAssertEqual(result, "||c|d|a||c|d|a|b|c|d|1|2|3|d") + // } + + func testCallself() throws { + let template = try Template( + """ + {% macro foo(x) %}{{ x }}{% if x > 1 %}|{{ foo(x - 1) }}{% endif %}{% endmacro %} + {{ foo(5) }} + """ + ) + let result = try template.render([:]) + XCTAssertEqual(result.trimmingCharacters(in: .whitespaces), "5|4|3|2|1") + } + + // TODO: Make this test pass + // func testArgumentsDefaultsNonsense() throws { + // // Test that macro with invalid argument defaults throws error + // let template = try Template( + // """ + // {% macro m(a, b=1, c) %}a={{ a }}, b={{ b }}, c={{ c }}{% endmacro %} + // """ + // ) + // XCTAssertThrowsError(try template.render([:])) + // } + + // TODO: Make this test pass + // func testCallerDefaultsNonsense() throws { + // let template = try Template( + // """ + // {% macro a() %}{{ caller() }}{% endmacro %} + // {% call(x, y=1, z) a() %}{% endcall %} + // """ + // ) + // XCTAssertThrowsError(try template.render([:])) + // } + + // TODO: Make this test pass + // func testVarargs() throws { + // let template = try Template( + // """ + // {% macro test() %}{{ varargs|join('|') }}{% endmacro %}\ + // {{ test(1, 2, 3) }} + // """ + // ) + // let result = try template.render([:]) + // XCTAssertEqual(result, "1|2|3") + // } + + // TODO: Make this test pass + // func testSimpleCall() throws { + // let template = try Template( + // """ + // {% macro test() %}[[{{ caller() }}]]{% endmacro %}\ + // {% call test() %}data{% endcall %} + // """ + // ) + // let result = try template.render([:]) + // XCTAssertEqual(result, "[[data]]") + // } + + // TODO: Make this test pass + // func testComplexCall() throws { + // let template = try Template( + // """ + // {% macro test() %}[[{{ caller('data') }}]]{% endmacro %}\ + // {% call(data) test() %}{{ data }}{% endcall %} + // """ + // ) + // let result = try template.render([:]) + // XCTAssertEqual(result, "[[data]]") + // } + + // TODO: Make this test pass + // func testCallerUndefined() throws { + // let template = try Template( + // """ + // {% set caller = 42 %}\ + // {% macro test() %}{{ caller is not defined }}{% endmacro %}\ + // {{ test() }} + // """ + // ) + // let result = try template.render([:]) + // XCTAssertEqual(result, "true") + // } +} + +final class SetTests: XCTestCase { + // MARK: - Set Tests + + func testNormalSet() throws { + let template = try Template("{% set foo = 1 %}{{ foo }}") + let result = try template.render([:]) + XCTAssertEqual(result, "1") + } + + // TODO: Make this test pass + // func testBlockSet() throws { + // let template = try Template("{% set foo %}42{% endset %}{{ foo }}") + // let result = try template.render([:]) + // XCTAssertEqual(result, "42") + // } + + func testNamespace() throws { + let template = try Template( + """ + {% set ns = namespace() %}{% set ns.bar = '42' %}{{ ns.bar }} + """ + ) + let result = try template.render([:]) + XCTAssertEqual(result, "42") + } + + // TODO: Make this test pass + // func testNamespaceLoop() throws { + // let template = try Template( + // """ + // {% set ns = namespace(found=false) %}\ + // {% for x in range(4) %}\ + // {% if x == v %}\ + // {% set ns.found = true %}\ + // {% endif %}\ + // {% endfor %}\ + // {{ ns.found }} + // """ + // ) + // + // let result1 = try template.render(["v": 3]) + // XCTAssertEqual(result1, "true") + // + // let result2 = try template.render(["v": 4]) + // XCTAssertEqual(result2, "false") + // } + + // TODO: Make this test pass + // func testNamespaceBlock() throws { + // let template = try Template( + // """ + // {% set ns = namespace() %}{% set ns.bar %}42{% endset %}{{ ns.bar }} + // """ + // ) + // let result = try template.render([:]) + // XCTAssertEqual(result, "42") + // } + + // TODO: Make this test pass + // func testInitNamespace() throws { + // let template = try Template( + // """ + // {% set ns = namespace(d, self=37) %} + // {% set ns.b = 42 %} + // {{ ns.a }}|{{ ns.self }}|{{ ns.b }} + // """ + // ) + // let result = try template.render(["d": ["a": 13]]) + // XCTAssertEqual(result.trimmingCharacters(in: .whitespaces), "13|37|42") + // } + + // TODO: Make this test pass + // func testNamespaceMacro() throws { + // let template = try Template( + // """ + // {% set ns = namespace() %} + // {% set ns.a = 13 %} + // {% macro magic(x) %} + // {% set x.b = 37 %} + // {% endmacro %} + // {{ magic(ns) }} + // {{ ns.a }}|{{ ns.b }} + // """ + // ) + // let result = try template.render([:]) + // XCTAssertEqual(result.trimmingCharacters(in: .whitespaces), "13|37") + // } + + // TODO: Make this test pass + // func testNamespaceSetTuple() throws { + // let template = try Template( + // """ + // {% set ns = namespace(a=12, b=36) %} + // {% set ns.a, ns.b = ns.a + 1, ns.b + 1 %} + // {{ ns.a }}|{{ ns.b }} + // """ + // ) + // let result = try template.render([:]) + // XCTAssertEqual(result.trimmingCharacters(in: .whitespaces), "13|37") + // } + + // TODO: Make this test pass + // func testBlockEscaping() throws { + // let template = try Template( + // """ + // {% set foo %}{{ test }}{% endset %}\ + // foo: {{ foo }} + // """ + // ) + // let result = try template.render(["test": ""]) + // XCTAssertEqual( + // result.trimmingCharacters(in: .whitespaces), + // "foo: <unsafe>" + // ) + // } + + // TODO: Make this test pass + // func testBlockEscapingFiltered() throws { + // let template = try Template( + // """ + // {% set foo | trim %}{{ test }} {% endset %}\ + // foo: {{ foo }} + // """ + // ) + // let result = try template.render(["test": ""]) + // XCTAssertEqual( + // result.trimmingCharacters(in: .whitespaces), + // "foo: <unsafe>" + // ) + // } + + // TODO: Make this test pass + // func testBlockFiltered() throws { + // let template = try Template( + // """ + // {% set foo | trim | length | string %} 42 {% endset %}\ + // {{ foo }} + // """ + // ) + // let result = try template.render([:]) + // XCTAssertEqual(result.trimmingCharacters(in: .whitespaces), "2") + // } + + // TODO: Make this test pass + // func testSetInvalid() throws { + // // Test invalid set syntax + // let template1 = try Template("{% set foo['bar'] = 1 %}") + // XCTAssertThrowsError(try template1.render([:])) + // + // // Test setting attribute on non-namespace + // let template2 = try Template("{% set foo.bar = 1 %}") + // XCTAssertThrowsError(try template2.render(["foo": [:]])) + // } + + func testNamespaceRedefined() throws { + let template = try Template( + """ + {% set ns = namespace() %}\ + {% set ns.bar = 'hi' %} + """ + ) + XCTAssertThrowsError(try template.render(["namespace": [String: Any].self])) + } +} + +// TODO: Make these tests pass +//final class WithTests: XCTestCase { +// func testWith() throws { +// let template = try Template( +// """ +// {% with a=42, b=23 -%} +// {{ a }} = {{ b }} +// {% endwith -%} +// {{ a }} = {{ b }} +// """ +// ) +// let result = try template.render(["a": 1, "b": 2]) +// let lines = result.split(separator: "\n").map { $0.trimmingCharacters(in: .whitespaces) } +// XCTAssertEqual(lines, ["42 = 23", "1 = 2"]) +// } +// +// func testWithArgumentScoping() throws { +// let template = try Template( +// """ +// {%- with a=1, b=2, c=b, d=e, e=5 -%} +// {{ a }}|{{ b }}|{{ c }}|{{ d }}|{{ e }} +// {%- endwith -%} +// """ +// ) +// let result = try template.render(["b": 3, "e": 4]) +// XCTAssertEqual(result, "1|2|3|4|5") +// } +//} diff --git a/Tests/FilterTests.swift b/Tests/FilterTests.swift new file mode 100644 index 0000000..362b249 --- /dev/null +++ b/Tests/FilterTests.swift @@ -0,0 +1,1115 @@ +// +// FilterTests.swift +// Jinja +// +// Created by Anthony DePasquale on 07.01.2025. +// + +// Adapted from https://github.com/pallets/jinja/blob/main/tests/test_filters.py + +import XCTest +import OrderedCollections + +@testable import Jinja + +final class FilterTests: XCTestCase { + func testFilters() throws { + // Helper function to run tests for a filter + func runTest( + filterName: String, + input: Any, + args: [Any?] = [], + expected: Any, + file: StaticString = #file, + line: UInt = #line + ) throws { + let env = Environment() + + // Convert input to RuntimeValue + guard let input = try? env.convertToRuntimeValues(input: input) else { + XCTFail( + "Failed to convert input \(input) to RuntimeValue in test for \(filterName)", + file: file, + line: line + ) + return + } + + // Set the input value in the environment + try env.set(name: "input", value: input) + + // Set filter arguments in the environment + for (index, arg) in args.enumerated() { + if let arg { + try env.set(name: "arg\(index)", value: arg) + } + } + + // Construct the filter arguments for direct call + var filterArgs: [any RuntimeValue] = [input] + for (index, _) in args.enumerated() { + filterArgs.append(env.lookupVariable(name: "arg\(index)")) + } + + // Get the filter function from the environment + guard let filter = env.filters[filterName] else { + XCTFail("Filter not found: \(filterName)", file: file, line: line) + return + } + + // Call the filter function directly with the input and arguments + let result = try filter(filterArgs, env) + + // Perform assertions based on the expected type + if let expectedString = expected as? String { + XCTAssertEqual( + (result as? StringValue)?.value, + expectedString, + "\(filterName) filter failed", + file: file, + line: line + ) + } else if let expectedInt = expected as? Int { + XCTAssertEqual( + (result as? NumericValue)?.value as? Int, + expectedInt, + "\(filterName) filter failed", + file: file, + line: line + ) + } else if let expectedDouble = expected as? Double { + XCTAssertEqual( + (result as? NumericValue)?.value as? Double, + expectedDouble, + "\(filterName) filter failed", + file: file, + line: line + ) + } else if let expectedBool = expected as? Bool { + XCTAssertEqual( + (result as? BooleanValue)?.value, + expectedBool, + "\(filterName) filter failed", + file: file, + line: line + ) + } else if expected is UndefinedValue { + XCTAssertTrue( + result is UndefinedValue, + "\(filterName) filter failed", + file: file, + line: line + ) + } else if let expectedArray = expected as? [String] { + guard let resultArray = (result as? ArrayValue)?.value else { + XCTFail( + "\(filterName) filter failed: Expected [String], got \(type(of: result)), value: \(result)", + file: file, + line: line + ) + return + } + let resultStrings = resultArray.compactMap { value -> String? in + if let stringValue = value as? StringValue { + return stringValue.value + } else if value is NullValue { + return "None" + } + return nil + } + XCTAssertEqual( + resultStrings, + expectedArray, + "\(filterName) filter failed", + file: file, + line: line + ) + } else if let expectedArray = expected as? [Int] { + guard let resultArray = (result as? ArrayValue)?.value else { + XCTFail( + "\(filterName) filter failed: Expected [Int], got \(type(of: result))", + file: file, + line: line + ) + return + } + let resultInts = resultArray.compactMap { ($0 as? NumericValue)?.value as? Int } + XCTAssertEqual( + resultInts, + expectedArray, + "\(filterName) filter failed", + file: file, + line: line + ) + } else if let expectedArray = expected as? [[String]] { + guard let resultArray = (result as? ArrayValue)?.value else { + XCTFail( + "\(filterName) filter failed: Expected [[String]], got \(type(of: result))", + file: file, + line: line + ) + return + } + let resultArrays = resultArray.compactMap { value -> [String]? in + if let arrayValue = value as? ArrayValue { + return arrayValue.value.compactMap { ($0 as? StringValue)?.value } + } else if let stringValue = value as? StringValue { + return [stringValue.value] + } + return nil + } + XCTAssertEqual( + resultArrays, + expectedArray, + "\(filterName) filter failed", + file: file, + line: line + ) + } else if let expectedArray = expected as? [[Int]] { + guard let resultArray = (result as? ArrayValue)?.value as? [ArrayValue] else { + XCTFail( + "\(filterName) filter failed: Expected [[Int]], got \(type(of: result))", + file: file, + line: line + ) + return + } + let resultInts = resultArray.map { $0.value.compactMap { ($0 as? NumericValue)?.value as? Int } } + XCTAssertEqual( + resultInts, + expectedArray, + "\(filterName) filter failed", + file: file, + line: line + ) + } else if let expectedDict = expected as? [String: Any] { + guard let resultDict = (result as? ObjectValue)?.value else { + XCTFail( + "\(filterName) filter failed: Expected [String: Any], got \(type(of: result))", + file: file, + line: line + ) + return + } + XCTAssertEqual( + resultDict.count, + expectedDict.count, + "\(filterName) filter failed: Dictionary count mismatch", + file: file, + line: line + ) + for (key, expectedValue) in expectedDict { + guard let resultValue = resultDict[key] else { + XCTFail( + "\(filterName) filter failed: Missing key '\(key)' in result", + file: file, + line: line + ) + return + } + if let expectedString = expectedValue as? String { + XCTAssertEqual( + (resultValue as? StringValue)?.value, + expectedString, + "\(filterName) filter failed for key '\(key)'", + file: file, + line: line + ) + } else if let expectedInt = expectedValue as? Int { + XCTAssertEqual( + ((resultValue as? NumericValue)?.value as? Int), + expectedInt, + "\(filterName) filter failed for key '\(key)'", + file: file, + line: line + ) + } else if let expectedDouble = expectedValue as? Double { + XCTAssertEqual( + ((resultValue as? NumericValue)?.value as? Double), + expectedDouble, + "\(filterName) filter failed for key '\(key)'", + file: file, + line: line + ) + } else if let expectedBool = expectedValue as? Bool { + XCTAssertEqual( + (resultValue as? BooleanValue)?.value, + expectedBool, + "\(filterName) filter failed for key '\(key)'", + file: file, + line: line + ) + } else if expectedValue is UndefinedValue { + XCTAssertTrue( + resultValue is UndefinedValue, + "\(filterName) filter failed for key '\(key)'", + file: file, + line: line + ) + } else { + XCTFail( + "\(filterName) filter failed: Unsupported type for key '\(key)'", + file: file, + line: line + ) + } + } + } else if let expectedArray = expected as? [(String, Any)] { + guard let resultArray = (result as? ArrayValue)?.value as? [ArrayValue] else { + XCTFail( + "\(filterName) filter failed: Expected [(String, Any)], got \(type(of: result))", + file: file, + line: line + ) + return + } + + XCTAssertEqual( + resultArray.count, + expectedArray.count, + "\(filterName) filter failed", + file: file, + line: line + ) + + for (index, expectedTuple) in expectedArray.enumerated() { + let resultTuple = resultArray[index].value + + guard resultTuple.count == 2 else { + XCTFail( + "\(filterName) filter failed at index \(index): Result tuple does not have 2 elements", + file: file, + line: line + ) + return + } + + XCTAssertEqual( + (resultTuple[0] as? StringValue)?.value, + expectedTuple.0, + "\(filterName) filter failed at index \(index)", + file: file, + line: line + ) + + if let expectedInt = expectedTuple.1 as? Int { + XCTAssertEqual( + ((resultTuple[1] as? NumericValue)?.value as? Int), + expectedInt, + "\(filterName) filter failed at index \(index)", + file: file, + line: line + ) + } else if let expectedString = expectedTuple.1 as? String { + XCTAssertEqual( + (resultTuple[1] as? StringValue)?.value, + expectedString, + "\(filterName) filter failed at index \(index)", + file: file, + line: line + ) + } else { + XCTFail( + "\(filterName) filter failed: Unsupported type for tuple element at index \(index)", + file: file, + line: line + ) + } + } + } else if let expectedMixedArray = expected as? [Any] { + guard let resultArray = (result as? ArrayValue)?.value else { + XCTFail( + "\(filterName) filter failed: Expected [Any], got \(type(of: result))", + file: file, + line: line + ) + return + } + + // Convert both arrays to strings for comparison since they may contain mixed types + let resultStrings = resultArray.map { value -> String in + if let arrayValue = value as? ArrayValue { + return "[" + + arrayValue.value.map { + if let strValue = $0 as? StringValue { + return strValue.value + } + return String(describing: $0) + }.joined(separator: ", ") + "]" + } else if let stringValue = value as? StringValue { + return stringValue.value + } else { + return String(describing: value) + } + } + + let expectedStrings = expectedMixedArray.map { value -> String in + if let array = value as? [String] { + return "[" + array.joined(separator: ", ") + "]" + } else { + return String(describing: value) + } + } + + XCTAssertEqual( + resultStrings, + expectedStrings, + "\(filterName) filter failed", + file: file, + line: line + ) + } else if let expectedGroups = expected as? [[String: Any]] { + // For "groupby" filter + // Convert both expected and actual results to JSON strings for comparison + let expectedJSON = try toJSON(try env.convertToRuntimeValues(input: expectedGroups)) + let actualJSON = try toJSON(result) + + XCTAssertEqual( + expectedJSON, + actualJSON, + "\(filterName) filter failed: Expected \(expectedJSON) but got \(actualJSON)", + file: file, + line: line + ) + } else { + XCTFail( + "\(filterName) filter failed: Unsupported expected type \(type(of: expected))", + file: file, + line: line + ) + } + } + + // Test abs + try runTest(filterName: "abs", input: -1, expected: 1) + try runTest(filterName: "abs", input: 1, expected: 1) + try runTest(filterName: "abs", input: -3.14, expected: 3.14) + try runTest(filterName: "abs", input: 3.14, expected: 3.14) + + // Test attr + try runTest( + filterName: "attr", + input: ["name": "John"], + args: ["name"], + expected: "John" + ) + try runTest( + filterName: "attr", + input: ["age": 30], + args: ["age"], + expected: 30 + ) + try runTest( + filterName: "attr", + input: ["name": "John"], + args: ["age"], + expected: UndefinedValue() + ) + + // Test batch + try runTest( + filterName: "batch", + input: [1, 2, 3, 4, 5, 6, 7, 8, 9], + args: [3], + expected: [[1, 2, 3], [4, 5, 6], [7, 8, 9]] + ) + try runTest( + filterName: "batch", + input: [1, 2, 3, 4, 5, 6, 7, 8, 9], + args: [3, 0], + expected: [[1, 2, 3], [4, 5, 6], [7, 8, 9]] + ) + try runTest( + filterName: "batch", + input: [1, 2, 3, 4, 5, 6, 7, 8], + args: [3, 0], + expected: [[1, 2, 3], [4, 5, 6], [7, 8, 0]] + ) + + // Test capitalize + try runTest(filterName: "capitalize", input: "foo bar", expected: "Foo bar") + try runTest(filterName: "capitalize", input: "FOO BAR", expected: "Foo bar") + + // Test center + try runTest( + filterName: "center", + input: "foo", + expected: String(repeating: " ", count: 38) + "foo" + String(repeating: " ", count: 39) + ) // Default width 80 + try runTest(filterName: "center", input: "foo", args: [NumericValue(value: 11)], expected: " foo ") + try runTest(filterName: "center", input: "foo", args: [NumericValue(value: 5)], expected: " foo ") + try runTest(filterName: "center", input: "foo", args: [NumericValue(value: 4)], expected: "foo ") + try runTest(filterName: "center", input: "foo", args: [NumericValue(value: 3)], expected: "foo") + try runTest(filterName: "center", input: "foo", args: [NumericValue(value: 2)], expected: "foo") + + // Test count + try runTest(filterName: "count", input: "Hello", expected: 5) + try runTest(filterName: "count", input: [1, 2, 3].map { NumericValue(value: $0) }, expected: 3) + try runTest( + filterName: "count", + input: ObjectValue(value: ["name": StringValue(value: "John"), "age": NumericValue(value: 30)]), + expected: 2 + ) + + // Test default + try runTest(filterName: "default", input: UndefinedValue(), expected: "") + try runTest(filterName: "default", input: UndefinedValue(), args: ["foo"], expected: "foo") + try runTest(filterName: "default", input: false, args: ["foo", true], expected: "foo") + try runTest(filterName: "default", input: true, args: ["foo", true], expected: "true") + try runTest(filterName: "default", input: "bar", args: ["foo"], expected: "bar") + + // Test dictsort + try runTest( + filterName: "dictsort", + input: OrderedDictionary(dictionaryLiteral: ("f", 5), ("b", 4), ("c", 3), ("d", 2), ("a", 1)), + expected: [("a", 1), ("b", 4), ("c", 3), ("d", 2), ("f", 5)] + ) + try runTest( + filterName: "dictsort", + input: OrderedDictionary(dictionaryLiteral: ("f", 5), ("b", 4), ("c", 3), ("d", 2), ("a", 1)), + args: [true], + expected: [("a", 1), ("b", 4), ("c", 3), ("d", 2), ("f", 5)] + ) + try runTest( + filterName: "dictsort", + input: OrderedDictionary(dictionaryLiteral: ("f", 5), ("b", 4), ("c", 3), ("d", 2), ("a", 1)), + args: [false, "value"], + expected: [("a", 1), ("d", 2), ("c", 3), ("b", 4), ("f", 5)] + ) + try runTest( + filterName: "dictsort", + input: OrderedDictionary(dictionaryLiteral: ("f", 5), ("b", 4), ("c", 3), ("d", 2), ("a", 1)), + args: [false, "key", true], + expected: [("f", 5), ("d", 2), ("c", 3), ("b", 4), ("a", 1)] + ) + try runTest( + filterName: "dictsort", + input: OrderedDictionary(dictionaryLiteral: ("f", 5), ("b", 4), ("c", 3), ("d", 2), ("a", 1)), + args: [false, "value", true], + expected: [("f", 5), ("b", 4), ("c", 3), ("d", 2), ("a", 1)] + ) + + // Test escape + try runTest(filterName: "escape", input: "", expected: "<foo>") + try runTest(filterName: "escape", input: "foo & bar", expected: "foo & bar") + + // Test filesizeformat + try runTest(filterName: "filesizeformat", input: 100, expected: "100 Bytes") + try runTest(filterName: "filesizeformat", input: 1000, expected: "1.0 kB") + try runTest(filterName: "filesizeformat", input: 1_000_000, expected: "1.0 MB") + try runTest(filterName: "filesizeformat", input: 1_000_000_000, expected: "1.0 GB") + try runTest( + filterName: "filesizeformat", + input: 1_000_000_000_000, + expected: "1.0 TB" + ) + try runTest(filterName: "filesizeformat", input: 300, expected: "300 Bytes") + try runTest(filterName: "filesizeformat", input: 3000, expected: "3.0 kB") + try runTest(filterName: "filesizeformat", input: 3_000_000, expected: "3.0 MB") + try runTest(filterName: "filesizeformat", input: 3_000_000_000, expected: "3.0 GB") + try runTest( + filterName: "filesizeformat", + input: 3_000_000_000_000, + expected: "3.0 TB" + ) + try runTest( + filterName: "filesizeformat", + input: 100, + args: [true], + expected: "100 Bytes" + ) + try runTest( + filterName: "filesizeformat", + input: 1000, + args: [true], + expected: "1000 Bytes" + ) + try runTest( + filterName: "filesizeformat", + input: 1_000_000, + args: [true], + expected: "976.6 KiB" + ) + try runTest( + filterName: "filesizeformat", + input: 1_000_000_000, + args: [true], + expected: "953.7 MiB" + ) + try runTest( + filterName: "filesizeformat", + input: 1_000_000_000_000, + args: [true], + expected: "931.3 GiB" + ) + try runTest( + filterName: "filesizeformat", + input: 300, + args: [true], + expected: "300 Bytes" + ) + try runTest( + filterName: "filesizeformat", + input: 3000, + args: [true], + expected: "2.9 KiB" + ) + try runTest( + filterName: "filesizeformat", + input: 3_000_000, + args: [true], + expected: "2.9 MiB" + ) + + // Test first + try runTest(filterName: "first", input: [1, 2, 3], expected: 1) + try runTest(filterName: "first", input: ["a", "b", "c"], expected: "a") + try runTest(filterName: "first", input: [], expected: UndefinedValue()) + + // Test float + try runTest(filterName: "float", input: 42, expected: 42.0) + try runTest(filterName: "float", input: 42.5, expected: 42.5) + try runTest(filterName: "float", input: "42", expected: 0.0) + try runTest(filterName: "float", input: "42.5", expected: 0.0) + + // Test forceescape + try runTest(filterName: "forceescape", input: "", expected: "<foo>") + try runTest(filterName: "forceescape", input: "foo & bar", expected: "foo & bar") + + // Test format + try runTest(filterName: "format", input: "%s %s", args: ["Hello", "World"], expected: "Hello World") + try runTest(filterName: "format", input: "%d", args: [123], expected: "123") + + // TODO: Test groupby + + // Test indent + try runTest( + filterName: "indent", + input: "Hello\nWorld", + expected: "Hello\n World" + ) // Default: width=4, first=false, blank=false + try runTest( + filterName: "indent", + input: "Hello\nWorld", + args: [2], + expected: "Hello\n World" + ) // width=2 + try runTest( + filterName: "indent", + input: "Hello\nWorld", + args: [4, true], + expected: " Hello\n World" + ) // first=true + try runTest( + filterName: "indent", + input: "\nfoo bar\n\"baz\"\n", + args: [2, false, false], + expected: "\n foo bar\n \"baz\"\n" + ) // blank=false + try runTest( + filterName: "indent", + input: "\nfoo bar\n\"baz\"\n", + args: [2, false, true], + expected: "\n foo bar\n \"baz\"\n " + ) // blank=true + try runTest( + filterName: "indent", + input: "\nfoo bar\n\"baz\"\n", + args: [2, true, false], + expected: " \n foo bar\n \"baz\"\n" + ) // first=true, blank=false + try runTest( + filterName: "indent", + input: "\nfoo bar\n\"baz\"\n", + args: [2, true, true], + expected: " \n foo bar\n \"baz\"\n " + ) // first=true, blank=true + try runTest( + filterName: "indent", + input: "jinja", + expected: "jinja" + ) // Single line, no indent + try runTest( + filterName: "indent", + input: "jinja", + args: [4, true], + expected: " jinja" + ) // Single line, first=true + try runTest( + filterName: "indent", + input: "jinja", + args: [4, false, true], + expected: "jinja" + ) // Single line, blank=true (no effect) + try runTest( + filterName: "indent", + input: "jinja\nflask", + args: [">>> ", true], + expected: ">>> jinja\n>>> flask" + ) // String width, first=true + + // Test int + try runTest(filterName: "int", input: 42.0, expected: 42) + try runTest(filterName: "int", input: 42.5, expected: 42) + try runTest(filterName: "int", input: "42", expected: 42) + try runTest(filterName: "int", input: "42.5", expected: 42) + + // Test items + // Test with dictionary + try runTest( + filterName: "items", + input: OrderedDictionary( + dictionaryLiteral: ("0", "a"), + ("1", "b"), + ("2", "c") + ), + expected: [ + ("0", "a"), + ("1", "b"), + ("2", "c"), + ] + ) + // Test with undefined value + try runTest( + filterName: "items", + input: UndefinedValue(), + expected: [] + ) + // Test with invalid input (should throw error) + XCTAssertThrowsError( + try runTest( + filterName: "items", + input: [1, 2, 3], // Array instead of mapping + expected: [] + ) + ) { error in + XCTAssertEqual( + error as? JinjaError, + .runtime("Can only get item pairs from a mapping.") + ) + } + + // Test join + try runTest(filterName: "join", input: [1, 2, 3], expected: "123") + try runTest(filterName: "join", input: [1, 2, 3], args: [","], expected: "1,2,3") + try runTest(filterName: "join", input: ["a", "b", "c"], args: ["-"], expected: "a-b-c") + + // Test last + try runTest(filterName: "last", input: [1, 2, 3], expected: 3) + try runTest(filterName: "last", input: ["a", "b", "c"], expected: "c") + try runTest(filterName: "last", input: [], expected: UndefinedValue()) + + // Test length + try runTest(filterName: "length", input: "Hello", expected: 5) + try runTest(filterName: "length", input: [1, 2, 3], expected: 3) + try runTest(filterName: "length", input: ["name": "John", "age": 30], expected: 2) + + // Test list + try runTest(filterName: "list", input: [1, 2, 3], expected: [1, 2, 3]) + try runTest(filterName: "list", input: ["a", "b", "c"], expected: ["a", "b", "c"]) + + // Test lower + try runTest(filterName: "lower", input: "FOO", expected: "foo") + try runTest(filterName: "lower", input: "Foo", expected: "foo") + + // Test map + // Test simple map with int conversion + try runTest( + filterName: "map", + input: ["1", "2", "3"], + args: [StringValue(value: "int")], + expected: [1, 2, 3] + ) + + // TODO: Test `map` with `sum` (currently failing, may require changes to `map` or `sum`) + // try runFilterTest( + // filterName: "map", + // input: [[1, 2], [3], [4, 5, 6]], + // args: [StringValue(value: "sum")], + // expected: [3, 3, 15] + // ) + + // Test attribute map + try runTest( + filterName: "map", + input: [ + ["username": "john"], + ["username": "jane"], + ["username": "mike"], + ], + args: [ + ObjectValue(value: [ + "attribute": StringValue(value: "username") + ]) + ], + expected: ["john", "jane", "mike"] + ) + + // Test map with default value + try runTest( + filterName: "map", + input: [ + ["firstname": "john", "lastname": "lennon"], + ["firstname": "jane", "lastname": "edwards"], + ["firstname": "jon", "lastname": UndefinedValue()], + ["firstname": "mike"], + ], + args: [ + ObjectValue(value: [ + "attribute": StringValue(value: "lastname"), + "default": StringValue(value: "smith"), + ]) + ], + expected: ["lennon", "edwards", "None", "smith"] + ) + + // Test map with list default value + try runTest( + filterName: "map", + input: [ + ["firstname": "john", "lastname": "lennon"], + ["firstname": "jane", "lastname": "edwards"], + ["firstname": "jon", "lastname": UndefinedValue()], + ["firstname": "mike"], + ], + args: [ + ObjectValue(value: [ + "attribute": StringValue(value: "lastname"), + "default": ArrayValue(value: [ + StringValue(value: "smith"), + StringValue(value: "x"), + ]), + ]) + ], + expected: ["lennon", "edwards", "None", ["smith", "x"]] + ) + + // Test map with empty string default value + try runTest( + filterName: "map", + input: [ + ["firstname": "john", "lastname": "lennon"], + ["firstname": "jane", "lastname": "edwards"], + ["firstname": "jon", "lastname": UndefinedValue()], + ["firstname": "mike"], + ], + args: [ + ObjectValue(value: [ + "attribute": StringValue(value: "lastname"), + "default": StringValue(value: ""), + ]) + ], + expected: ["lennon", "edwards", "None", ""] + ) + + // Test min + try runTest(filterName: "min", input: [3, 1, 4, 2], expected: 1) + try runTest(filterName: "min", input: ["b", "a", "d", "c"], expected: "a") + try runTest(filterName: "min", input: [], expected: UndefinedValue()) + + // Test max + try runTest(filterName: "max", input: [3, 1, 4, 2], expected: 4) + try runTest(filterName: "max", input: ["b", "a", "d", "c"], expected: "d") + try runTest(filterName: "max", input: [], expected: UndefinedValue()) + + // TODO: Figure out how to test "pprint", given that Swift 5.10 doesn't preserve the key order in dictionaries + + // TODO: Figure out how to test "random" filter + + // Test reject + try runTest( + filterName: "reject", + input: [1, 2, 3, 4, 5], + args: ["even"], + expected: [1, 3, 5] + ) + + // TODO: Test rejectattr + // try runFilterTest( + // filterName: "rejectattr", + // input: [ + // ["name": "John", "admin": true], + // ["name": "Jane", "admin": false], + // ], + // args: ["admin"], + // expected: [ + // ["admin": false, "name": "Jane"] + // ] + // ) + + // Test replace + try runTest( + filterName: "replace", + input: "Hello World", + args: ["World", "Jinja"], + expected: "Hello Jinja" + ) + try runTest( + filterName: "replace", + input: "aaaa", + args: ["a", "b", 2], + expected: "bbbb" + ) + + // Test reverse + try runTest(filterName: "reverse", input: [1, 2, 3], expected: [3, 2, 1]) + try runTest(filterName: "reverse", input: ["a", "b", "c"], expected: ["c", "b", "a"]) + + // Test round + try runTest(filterName: "round", input: 42.55, expected: 43.0) + try runTest(filterName: "round", input: 42.55, args: [1], expected: 42.6) + try runTest(filterName: "round", input: 42.55, args: [1, "floor"], expected: 42.5) + try runTest(filterName: "round", input: 42.55, args: [1, "ceil"], expected: 42.6) + + // Test safe + try runTest(filterName: "safe", input: "", expected: "") + try runTest(filterName: "safe", input: "foo & bar", expected: "foo & bar") + + // Test select + try runTest( + filterName: "select", + input: [1, 2, 3, 4, 5], + args: ["even"], + expected: [2, 4] + ) + // TODO: Make this test pass + // try runFilterTest( + // filterName: "select", + // input: [ + // ["name": "John", "age": 30], + // ["name": "Jane", "age": 25], + // ], + // args: ["even"], + // expected: [["name": "John", "age": 30]] + // ) + + // TODO: Test selectattr + // try runFilterTest( + // filterName: "selectattr", + // input: [ + // ["name": "John", "admin": true], + // ["name": "Jane", "admin": false], + // ], + // args: ["admin"], + // expected: [["name": "John", "admin": true]] + // ) + // try runFilterTest( + // filterName: "selectattr", + // input: [ + // ["name": "John", "age": 30], + // ["name": "Jane", "age": 25], + // ], + // args: ["age", "equalto", 25], + // expected: [["name": "Jane", "age": 25]] + // ) + + // Test slice + try runTest( + filterName: "slice", + input: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + args: [3], + expected: [[1, 2, 3, 4], [5, 6, 7], [8, 9, 10]] + ) + try runTest( + filterName: "slice", + input: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + args: [3, 0], + expected: [[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 10, 0]] + ) + + // Test sort + try runTest(filterName: "sort", input: [3, 1, 4, 2], expected: [1, 2, 3, 4]) + try runTest(filterName: "sort", input: [3, 1, 4, 2], args: [true], expected: [4, 3, 2, 1]) + try runTest( + filterName: "sort", + input: ["b", "A", "d", "c"], + expected: ["A", "b", "c", "d"] + ) + try runTest( + filterName: "sort", + input: ["b", "A", "d", "c"], + args: [false, true], + expected: ["A", "b", "c", "d"] + ) + // TODO: Make these tests pass + // try runFilterTest( + // filterName: "sort", + // input: [ + // ["name": "John", "age": 30], + // ["name": "Jane", "age": 25], + // ], + // args: [false, false, "name"], + // expected: [ + // ["name": "Jane", "age": 25], + // ["name": "John", "age": 30], + // ] + // ) + // try runFilterTest( + // filterName: "sort", + // input: [ + // ["name": "John", "age": 30], + // ["name": "Jane", "age": 25], + // ], + // args: [false, false, "age"], + // expected: [ + // ["name": "Jane", "age": 25], + // ["name": "John", "age": 30], + // ] + // ) + + // Test string + try runTest(filterName: "string", input: 123, expected: "123") + try runTest(filterName: "string", input: true, expected: "true") + try runTest(filterName: "string", input: [1, 2, 3], expected: "[1, 2, 3]") + try runTest( + filterName: "string", + input: OrderedDictionary(dictionaryLiteral: ("a", 1), ("b", 2)), + expected: "{\"a\": 1, \"b\": 2}" + ) + + // Test striptags + try runTest( + filterName: "striptags", + input: "

Hello, World!

", + expected: "Hello, World!" + ) + try runTest( + filterName: "striptags", + input: "
Link", + expected: "Link" + ) + + // Test sum + try runTest(filterName: "sum", input: [1, 2, 3, 4, 5], expected: 15) + try runTest( + filterName: "sum", + input: [ + ["value": 1], + ["value": 2], + ["value": 3], + ], + args: ["value"], + expected: 6 + ) + try runTest(filterName: "sum", input: [1, 2, 3, 4, 5], args: [], expected: 15) + // TODO: Make this test pass + // try runFilterTest(filterName: "sum", input: [1, 2, 3, 4, 5], args: ["", 10], expected: 25) + + // Test title + try runTest(filterName: "title", input: "hello world", expected: "Hello World") + try runTest(filterName: "title", input: "HELLO WORLD", expected: "Hello World") + + // Test trim + try runTest(filterName: "trim", input: " hello ", expected: "hello") + try runTest(filterName: "trim", input: "\t hello \n ", expected: "hello") + + // Test truncate + try runTest(filterName: "truncate", input: "Hello World", expected: "Hello World") + try runTest(filterName: "truncate", input: "Hello World", args: [5], expected: "He...") + try runTest(filterName: "truncate", input: "Hello World", args: [5, true], expected: "He...") + try runTest(filterName: "truncate", input: "Hello World", args: [5, false], expected: "He...") + try runTest(filterName: "truncate", input: "Hello World", args: [5, false, "---"], expected: "He---") + try runTest(filterName: "truncate", input: "Hello Big World", args: [10, false], expected: "Hello...") + + // Test unique + try runTest(filterName: "unique", input: [2, 1, 2, 3, 4, 4], expected: [2, 1, 3, 4]) + try runTest(filterName: "unique", input: ["Foo", "foo", "bar"], expected: ["Foo", "bar"]) + try runTest( + filterName: "unique", + input: ["Foo", "foo", "bar"], + args: [true], + expected: ["Foo", "foo", "bar"] + ) + // TODO: Make these tests pass + // try runFilterTest( + // filterName: "unique", + // input: [ + // ["name": "foo", "id": 1], + // ["name": "foo", "id": 2], + // ["name": "bar", "id": 3], + // ], + // args: [false, "name"], + // expected: [["name": "foo", "id": 1], ["name": "bar", "id": 3]] + // ) + // try runFilterTest( + // filterName: "unique", + // input: [ + // ["name": "foo", "id": 1], + // ["name": "foo", "id": 2], + // ["name": "bar", "id": 3], + // ], + // args: [false, "id"], + // expected: [["name": "foo", "id": 1], ["name": "bar", "id": 3]] + // ) + try runTest( + filterName: "unique", + input: "abcba", + expected: ["a", "b", "c"] + ) + try runTest( //XCTAssertEqual failed: ("["a"]") is not equal to ("["a", "b", "c"]") - unique filter failed + filterName: "unique", + input: "abcba", + args: [false, 0], + expected: ["a", "b", "c"] + ) + + // Test upper + try runTest(filterName: "upper", input: "foo", expected: "FOO") + try runTest(filterName: "upper", input: "Foo", expected: "FOO") + + // TODO: Test urlencode + + // Test urlize + try runTest( + filterName: "urlize", + input: "http://www.example.com/", + expected: "http://www.example.com/" + ) + try runTest( + filterName: "urlize", + input: "www.example.com", + expected: "www.example.com" + ) + try runTest( + filterName: "urlize", + input: "http://www.example.com/", + args: [10], + expected: "http://www..." + ) + try runTest( + filterName: "urlize", + input: "http://www.example.com/", + args: [10, true], + expected: "http://www..." + ) + // TODO: Make this test pass + // try runFilterTest( + // filterName: "urlize", + // input: "http://www.example.com/", + // args: [10, true, "_blank"], + // expected: "http://www..." + // ) + + // Test wordcount + try runTest(filterName: "wordcount", input: "foo bar baz", expected: 3) + try runTest(filterName: "wordcount", input: "foo bar baz", expected: 3) + + // TODO: Test wordwrap + + // TODO: Test xmlattr + + // TODO: Fix key order in results using OrderedDictionary as input + // Test tojson + // try runFilterTest( + // filterName: "tojson", + // input: ["foo": 42, "bar": 23], + // expected: "{\n \"foo\": 42,\n \"bar\": 23\n}" + // ) + // try runFilterTest( + // filterName: "tojson", + // input: ["foo": 42, "bar": 23], + // args: [["indent": 4]], + // expected: "{\n \"foo\": 42,\n \"bar\": 23\n}" + // ) + } +} diff --git a/Tests/InterpreterTests.swift b/Tests/InterpreterTests.swift index d402f84..631d2e6 100644 --- a/Tests/InterpreterTests.swift +++ b/Tests/InterpreterTests.swift @@ -141,17 +141,18 @@ final class InterpreterTests: XCTestCase { for test in tests { let env = Environment() try env.set(name: "True", value: true) - for (key, value) in test.data { try env.set(name: key, value: value) } - let tokens = try tokenize(test.template, options: test.options) let parsed = try parse(tokens: tokens) let interpreter = Interpreter(env: env) - let result = try interpreter.run(program: parsed).value as! String - - XCTAssertEqual(result.debugDescription, test.target.debugDescription) + let result = try interpreter.run(program: parsed) + if let stringResult = result as? StringValue { + XCTAssertEqual(stringResult.value.debugDescription, test.target.debugDescription) + } else { + XCTFail("Expected a StringValue, but got \(type(of: result))") + } } } } diff --git a/Tests/Templates/ChatTemplateTests.swift b/Tests/Templates/ChatTemplateTests.swift new file mode 100644 index 0000000..ec1d704 --- /dev/null +++ b/Tests/Templates/ChatTemplateTests.swift @@ -0,0 +1,622 @@ +// +// ChatTemplateTests.swift +// +// +// Created by John Mai on 2024/3/24. +// + +import XCTest + +@testable import Jinja + +final class ChatTemplateTests: XCTestCase { + let messages: [[String: String]] = [ + [ + "role": "user", + "content": "Hello, how are you?", + ], + [ + "role": "assistant", + "content": "I'm doing great. How can I help you today?", + ], + [ + "role": "user", + "content": "I'd like to show off how chat templating works!", + ], + ] + + lazy var messagesWithSystemPrompt: [[String: String]] = + [ + [ + "role": "system", + "content": "You are a friendly chatbot who always responds in the style of a pirate", + ] + ] + messages + + func testGenericChatTemplate() throws { + let chatTemplate = + "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages, + "add_generation_prompt": false, + ]) + let target = + "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n" + XCTAssertEqual(result, target) + } + + func testFacebookBlenderbot400MDistill() throws { + let chatTemplate = + "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages, + "eos_token": "", + ]) + let target = + " Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!" + XCTAssertEqual(result, target) + } + + func testFacebookBlenderbotSmall90M() throws { + let chatTemplate = + "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages, + "eos_token": "", + ]) + let target = + " Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!" + XCTAssertEqual(result, target) + } + + func testBigscienceBloom() throws { + let chatTemplate = "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages, + "eos_token": "", + ]) + let target = + "Hello, how are you?I'm doing great. How can I help you today?I'd like to show off how chat templating works!" + XCTAssertEqual(result, target) + } + + func testEleutherAIGptNeox20b() throws { + let chatTemplate = "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages, + "eos_token": "<|endoftext|>", + ]) + let target = + "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>" + XCTAssertEqual(result, target) + } + + func testGPT2() throws { + let chatTemplate = "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages, + "eos_token": "<|endoftext|>", + ]) + let target = + "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>" + XCTAssertEqual(result, target) + } + + func testHfInternalTestingLlamaTokenizer1() throws { + let chatTemplate = + "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messagesWithSystemPrompt, + "bos_token": "", + "eos_token": "", + "USE_DEFAULT_PROMPT": true, + ]) + let target = + "[INST] <>\nYou are a friendly chatbot who always responds in the style of a pirate\n<>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]" + XCTAssertEqual(result, target) + } + + func testHfInternalTestingLlamaTokenizer2() throws { + let chatTemplate = + "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages, + "bos_token": "", + "eos_token": "", + "USE_DEFAULT_PROMPT": true, + ]) + let target = + "[INST] <>\nDEFAULT_SYSTEM_MESSAGE\n<>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]" + XCTAssertEqual(result, target) + } + + func testHfInternalTestingLlamaTokenizer3() throws { + let chatTemplate = + "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": [ + [ + "role": "user", + "content": "<>\nYou are a helpful assistant\n<> Hello, how are you?", + ], + [ + "role": "assistant", + "content": "I'm doing great. How can I help you today?", + ], + [ + "role": "user", + "content": "I'd like to show off how chat templating works!", + ], + ], + "bos_token": "", + "eos_token": "", + "USE_DEFAULT_PROMPT": true, + ]) + let target = + "[INST] <>\nYou are a helpful assistant\n<> Hello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]" + XCTAssertEqual(result, target) + } + + func testOpenaiWhisperLargeV3() throws { + let chatTemplate = "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages, + "eos_token": "<|endoftext|>", + ]) + let target = + "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>" + XCTAssertEqual(result, target) + } + + func testQwenQwen1_5_1_8BChat1() throws { + let chatTemplate = + "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages, + "add_generation_prompt": true, + ]) + let target = + "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI\'m doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI\'d like to show off how chat templating works!<|im_end|>\n<|im_start|>assistant\n" + XCTAssertEqual(result, target) + } + + func testQwenQwen1_5_1_8BChat2() throws { + let chatTemplate = + "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messagesWithSystemPrompt, + "add_generation_prompt": true, + ]) + let target = + "<|im_start|>system\nYou are a friendly chatbot who always responds in the style of a pirate<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI\'m doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI\'d like to show off how chat templating works!<|im_end|>\n<|im_start|>assistant\n" + XCTAssertEqual(result, target) + } + + func testQwenQwen1_5_1_8BChat3() throws { + let chatTemplate = + "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messagesWithSystemPrompt + ]) + let target = + "<|im_start|>system\nYou are a friendly chatbot who always responds in the style of a pirate<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI\'m doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI\'d like to show off how chat templating works!" + XCTAssertEqual(result, target) + } + + func testTHUDMChatglm36b() throws { + let chatTemplate = + "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messagesWithSystemPrompt + ]) + let target = + "[gMASK]sop<|system|>\n You are a friendly chatbot who always responds in the style of a pirate<|user|>\n Hello, how are you?<|assistant|>\n I\'m doing great. How can I help you today?<|user|>\n I\'d like to show off how chat templating works!" + XCTAssertEqual(result, target) + } + + func testGoogleGemma2bIt() throws { + let chatTemplate = + "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages + ]) + let target = + "user\nHello, how are you?\nmodel\nI\'m doing great. How can I help you today?\nuser\nI\'d like to show off how chat templating works!\n" + XCTAssertEqual(result, target) + } + + func testQwenQwen2_5_0_5BInstruct() throws { + let chatTemplate = + "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages + ]) + let target = + "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI\'m doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI\'d like to show off how chat templating works!<|im_end|>\n" + XCTAssertEqual(result, target) + } + + func testHuggingFaceH4Zephyr7bBetaAddGenerationPromptFalse() throws { + let chatTemplate = + "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messagesWithSystemPrompt, "eos_token": "", + "add_generation_prompt": false, + ] as [String: Any] + ) + let target = + "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate\n<|user|>\nHello, how are you?\n<|assistant|>\nI'm doing great. How can I help you today?\n<|user|>\nI'd like to show off how chat templating works!\n" + XCTAssertEqual(result, target) + } + + func testHuggingFaceH4Zephyr7bBetaAddGenerationPromptTrue() throws { + let chatTemplate = + "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": [ + [ + "role": "system", + "content": "You are a friendly chatbot who always responds in the style of a pirate", + ], + ["role": "user", "content": "How many helicopters can a human eat in one sitting?"], + ], "eos_token": "", "add_generation_prompt": true, + ] as [String: Any] + ) + let target = + "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate\n<|user|>\nHow many helicopters can a human eat in one sitting?\n<|assistant|>\n" + XCTAssertEqual(result, target) + } + + func testHuggingFaceH4Zephyr7bGemmaV0_1() throws { + let chatTemplate = + "{% if messages[0]['role'] == 'user' or messages[0]['role'] == 'system' %}{{ bos_token }}{% endif %}{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% elif messages[-1]['role'] == 'assistant' %}{{ eos_token }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages, "bos_token": "", "eos_token": "", + "add_generation_prompt": false, + ] as [String: Any] + ) + let target = + "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n" + XCTAssertEqual(result, target) + } + + func testTheBlokeMistral7BInstructV0_1GPTQ() throws { + let chatTemplate = + "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages, "bos_token": "", "eos_token": "", + ] as [String: Any] + ) + let target = + "[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]" + XCTAssertEqual(result, target) + } + + func testMistralaiMixtral8x7BInstructV0_1() throws { + let chatTemplate = + "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages, "bos_token": "", "eos_token": "", + ] as [String: Any] + ) + let target = + "[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?[INST] I'd like to show off how chat templating works! [/INST]" + XCTAssertEqual(result, target) + } + + func testCognitivecomputationsDolphin2_5Mixtral8x7b() throws { + let chatTemplate = + "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages + ] as [String: Any] + ) + let target = + "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n" + XCTAssertEqual(result, target) + } + + func testOpenchatOpenchat3_5_0106() throws { + let chatTemplate = + "{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages, "bos_token": "", "eos_token": "", + "add_generation_prompt": false, + ] as [String: Any] + ) + let target = + "GPT4 Correct User: Hello, how are you?<|end_of_turn|>GPT4 Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT4 Correct User: I'd like to show off how chat templating works!<|end_of_turn|>" + XCTAssertEqual(result, target) + } + + func testUpstageSOLAR10_7BInstructV1_0() throws { + let chatTemplate = + "{% for message in messages %}{% if message['role'] == 'system' %}{% if message['content']%}{{'### System:\n' + message['content']+'\n\n'}}{% endif %}{% elif message['role'] == 'user' %}{{'### User:\n' + message['content']+'\n\n'}}{% elif message['role'] == 'assistant' %}{{'### Assistant:\n' + message['content']}}{% endif %}{% if loop.last and add_generation_prompt %}{{ '### Assistant:\n' }}{% endif %}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages + ] as [String: Any] + ) + let target = + "### User:\nHello, how are you?\n\n### Assistant:\nI'm doing great. How can I help you today?### User:\nI'd like to show off how chat templating works!\n\n" + XCTAssertEqual(result, target) + } + + func testCodellamaCodeLlama70bInstructHf() throws { + let chatTemplate = + "{% if messages[0]['role'] == 'system' %}{% set user_index = 1 %}{% else %}{% set user_index = 0 %}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != ((loop.index0 + user_index) % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 %}{{ '' }}{% endif %}{% set content = 'Source: ' + message['role'] + '\n\n ' + message['content'] | trim %}{{ content + ' ' }}{% endfor %}{{'Source: assistant\nDestination: user\n\n '}}"; + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages + ] as [String: Any] + ) + let target = + "Source: user\n\n Hello, how are you? Source: assistant\n\n I'm doing great. How can I help you today? Source: user\n\n I'd like to show off how chat templating works! Source: assistant\nDestination: user\n\n " + XCTAssertEqual(result, target) + } + + func testDeciDeciLM7BInstruct() throws { + let chatTemplate = + "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '### User:\n' + message['content'] }}\n{% elif message['role'] == 'system' %}\n{{ '### System:\n' + message['content'] }}\n{% elif message['role'] == 'assistant' %}\n{{ '### Assistant:\n' + message['content'] }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '### Assistant:' }}\n{% endif %}\n{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages + ] as [String: Any] + ) + let target = + "### User:\nHello, how are you?\n### Assistant:\nI'm doing great. How can I help you today?\n### User:\nI'd like to show off how chat templating works!\n" + XCTAssertEqual(result, target) + } + + func testQwenQwen1_5_72BChat() throws { + let chatTemplate = + "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages + ] as [String: Any] + ) + let target = + "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n" + XCTAssertEqual(result, target) + } + + func testDeepseekAiDeepseekLlm7bChat() throws { + let chatTemplate = + "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages, "bos_token": "<|begin of sentence|>", + "eos_token": "<|end of sentence|>", + ] as [String: Any] + ) + let target = + "<|begin of sentence|>User: Hello, how are you?\n\nAssistant: I'm doing great. How can I help you today?<|end of sentence|>User: I'd like to show off how chat templating works!\n\n" + XCTAssertEqual(result, target) + } + + func testH2oaiH2oDanube1_8bChat() throws { + let chatTemplate = + "{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|prompt|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ '<|system|>' + message['content'] + eos_token }}{% elif message['role'] == 'assistant' %}{{ '<|answer|>' + message['content'] + eos_token }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '<|answer|>' }}{% endif %}{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages, "eos_token": "", + ] as [String: Any] + ) + let target = + "<|prompt|>Hello, how are you?<|answer|>I'm doing great. How can I help you today?<|prompt|>I'd like to show off how chat templating works!" + XCTAssertEqual(result, target) + } + + func testInternlmInternlm2Chat7b() throws { + let chatTemplate = + "{% if messages[0]['role'] == 'user' or messages[0]['role'] == 'system' %}{{ bos_token }}{% endif %}{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% elif messages[-1]['role'] == 'assistant' %}{{ eos_token }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages, "bos_token": "", "eos_token": "", + ] as [String: Any] + ) + let target = + "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n" + XCTAssertEqual(result, target) + } + + func testTheBlokedeepseekCoder33BInstructAWQ() throws { + let chatTemplate = + "{%- set found_item = false -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set found_item = true -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if not found_item -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{{'### Response:\\n'}}\n" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages + ] as [String: Any] + ) + let target = + "You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n### Instruction:\nHello, how are you?\n### Response:\nI'm doing great. How can I help you today?\n<|EOT|>\n### Instruction:\nI'd like to show off how chat templating works!\n### Response:\n" + XCTAssertEqual(result, target) + } + + func testEriczzzFalconRw1bChat() throws { + let chatTemplate = + "{% for message in messages %}{% if loop.index > 1 and loop.previtem['role'] != 'assistant' %}{{ ' ' }}{% endif %}{% if message['role'] == 'system' %}{{ '[SYS] ' + message['content'].strip() }}{% elif message['role'] == 'user' %}{{ '[INST] ' + message['content'].strip() }}{% elif message['role'] == 'assistant' %}{{ '[RESP] ' + message['content'] + eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ ' [RESP] ' }}{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages, "eos_token": "<|endoftext|>", + ] as [String: Any] + ) + let target = + "[INST] Hello, how are you? [RESP] I'm doing great. How can I help you today?<|endoftext|>[INST] I'd like to show off how chat templating works!" + XCTAssertEqual(result, target) + } + + func testAbacusaiSmaug34BV0_1() throws { + let chatTemplate = + "{%- for idx in range(0, messages|length) -%}\n{%- if messages[idx]['role'] == 'user' -%}\n{%- if idx > 1 -%}\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\n{%- else -%}\n{{- messages[idx]['content'] + ' [/INST]' -}}\n{%- endif -%}\n{% elif messages[idx]['role'] == 'system' %}\n{{- '[INST] <>\\n' + messages[idx]['content'] + '\\n<>\\n\\n' -}}\n{%- elif messages[idx]['role'] == 'assistant' -%}\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\n{% endif %}\n{% endfor %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages, "bos_token": "", "eos_token": "", + ] as [String: Any] + ) + let target = + "Hello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]" + XCTAssertEqual(result, target) + } + + func testMaywellSynatraMixtral8x7B() throws { + let chatTemplate = + "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n{% for message in messages %}{% if message['role'] == 'user' %}### Instruction:\n{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% elif message['role'] == 'assistant' %}### Response:\n{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% elif message['role'] == 'system' %}{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% endif %}\n{% endfor %}\n{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}\n### Response:\n{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages + ] as [String: Any] + ) + let target = + "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nHello, how are you?### Response:\nI'm doing great. How can I help you today?### Instruction:\nI'd like to show off how chat templating works!" + XCTAssertEqual(result, target) + } + + func testDeepseekAiDeepseekCoder33bInstruct() throws { + let chatTemplate = + "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messages, "bos_token": "<|begin of sentence|>", "eos_token": "<|EOT|>", + ] as [String: Any] + ) + let target = + "<|begin of sentence|>You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\n### Instruction:\nHello, how are you?\n### Response:\nI'm doing great. How can I help you today?\n<|EOT|>\n### Instruction:\nI'd like to show off how chat templating works!\n" + XCTAssertEqual(result, target) + } + + func testMaywellSynatraMixtral8x7B_2() throws { + let chatTemplate = + "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n{% for message in messages %}{% if message['role'] == 'user' %}### Instruction:\n{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% elif message['role'] == 'assistant' %}### Response:\n{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% elif message['role'] == 'system' %}{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% endif %}\n{% endfor %}\n{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}\n### Response:\n{% endif %}" + let template = try Template(chatTemplate) + let result = try template.render( + [ + "messages": messagesWithSystemPrompt + ] as [String: Any] + ) + let target = + "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\nYou are a friendly chatbot who always responds in the style of a pirate### Instruction:\nHello, how are you?### Response:\nI'm doing great. How can I help you today?### Instruction:\nI'd like to show off how chat templating works!" + XCTAssertEqual(result, target) + } + + func testMistralNemoInstruct2407() throws { + let chatTemplate = + "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{%- for message in loop_messages | rejectattr(\"role\", \"equalto\", \"tool\") | rejectattr(\"role\", \"equalto\", \"tool_results\") | selectattr(\"tool_calls\", \"undefined\") %}\n {%- if (message[\"role\"] == \"user\") != (loop.index0 % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS][\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST]\" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST]\" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message[\"role\"] == \"tool_calls\" or message.tool_calls is defined %}\n {%- if message.tool_calls is defined %}\n {%- set tool_calls = message.tool_calls %}\n {%- else %}\n {%- set tool_calls = message.content %}\n {%- endif %}\n {{- \"[TOOL_CALLS][\" }}\n {%- for tool_call in tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- message[\"content\"] + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS]{\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messages, + "bos_token": "", + "eos_token": "", + ]) + let target = + "[INST]Hello, how are you?[/INST]I'm doing great. How can I help you today?[INST]I'd like to show off how chat templating works![/INST]" + + XCTAssertEqual(result, target) + } + + func testQwen2VLTextOnly() throws { + let qwen2VLChatTemplate = + "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}" + let template = try Template(qwen2VLChatTemplate) + let result = try template.render([ + "messages": messages, + "add_generation_prompt": true, + ]) + let target = """ + <|im_start|>system + You are a helpful assistant.<|im_end|> + <|im_start|>user + Hello, how are you?<|im_end|> + <|im_start|>assistant + I'm doing great. How can I help you today?<|im_end|> + <|im_start|>user + I'd like to show off how chat templating works!<|im_end|> + <|im_start|>assistant + + """ + XCTAssertEqual(result, target) + } + + func testPhi4() throws { + let userMessage = [ + "role": "user", + "content": "What is the weather in Paris today?", + ] + let chatTemplate = """ + {% for message in messages %}{% if (message['role'] == 'system') %}{{'<|im_start|>system<|im_sep|>' + message['content'] + '<|im_end|>'}}{% elif (message['role'] == 'user') %}{{'<|im_start|>user<|im_sep|>' + message['content'] + '<|im_end|><|im_start|>assistant<|im_sep|>'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|im_end|>'}}{% endif %}{% endfor %} + """ + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": [userMessage], + "bos_token": "<|begin_of_text|>", + "add_generation_prompt": true, + ]) + let target = """ + <|im_start|>user<|im_sep|>What is the weather in Paris today?<|im_end|><|im_start|>assistant<|im_sep|> + """ + XCTAssertEqual(result, target) + } + + func testDeepSeekQwen() throws { + let userMessage = [ + "role": "user", + "content": "What is the weather in Paris today?", + ] + let chatTemplate = """ + {% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '' in content %}{% set content = content.split('')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %} + """ + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": [userMessage], + "bos_token": "<|begin_of_text|>", + "add_generation_prompt": true, + ]) + let target = """ + <|begin_of_text|><|User|>What is the weather in Paris today?<|Assistant|> + """ + XCTAssertEqual(result, target) + } +} diff --git a/Tests/Templates/ChatTemplates.swift b/Tests/Templates/ChatTemplates.swift new file mode 100644 index 0000000..f7e72e1 --- /dev/null +++ b/Tests/Templates/ChatTemplates.swift @@ -0,0 +1,21 @@ +// +// ChatTemplates.swift +// Jinja +// +// Created by Anthony DePasquale on 02.01.2025. +// + +struct ChatTemplate { + static let llama3_1 = """ + {{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message + builtin tools #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if builtin_tools is defined or tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n {{- \"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- \"<|python_tag|>\" + tool_call.name + \".call(\" }}\n {%- for arg_name, arg_val in tool_call.arguments | items %}\n {{- arg_name + '=\"' + arg_val + '\"' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \")\" }}\n {%- else %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {%- endif %}\n {%- if builtin_tools is defined %}\n {#- This means we're in ipython mode #}\n {{- \"<|eom_id|>\" }}\n {%- else %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n + """ + static let llama3_2 = """ + {{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {{- \"<|eot_id|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n + """ + static let qwen2_5 = """ + {%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n + """ + static let mistral7b = """ + {{bos_token}}{% set user_messages = messages | selectattr('role', 'equalto', 'user') | list %}{% for message in messages %}{% if message['role'] == 'user' %}{% if message == user_messages[-1] %}{% if tools %}{{'[AVAILABLE_TOOLS]'+ tools|string + '[/AVAILABLE_TOOLS]'}}{% endif %}{{ '[INST]' + message['content'] + '[/INST]' }}{% else %}{{ '[INST]' + message['content'] + '[/INST]' }}{% endif %}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token}}{% elif message['role'] == 'tool_results' %}{{'[TOOL_RESULTS]' + message['content']|string + '[/TOOL_RESULTS]'}}{% elif message['role'] == 'tool_calls' %}{{'[TOOL_CALLS]' + message['content']|string + eos_token}}{% endif %}{% endfor %} + """ +} diff --git a/Tests/Templates/Messages.swift b/Tests/Templates/Messages.swift new file mode 100644 index 0000000..2159be3 --- /dev/null +++ b/Tests/Templates/Messages.swift @@ -0,0 +1,15 @@ +// +// Messages.swift +// Jinja +// +// Created by Anthony DePasquale on 02.01.2025. +// + +struct Messages { + static let weatherQuery: [[String: String]] = [ + [ + "role": "user", + "content": "What is the weather in Paris today?", + ] + ] +} diff --git a/Tests/Templates/ToolSpecs.swift b/Tests/Templates/ToolSpecs.swift new file mode 100644 index 0000000..daadbce --- /dev/null +++ b/Tests/Templates/ToolSpecs.swift @@ -0,0 +1,48 @@ +// +// ToolSpecs.swift +// Jinja +// +// Created by Anthony DePasquale on 02.01.2025. +// + +import OrderedCollections + +struct ToolSpec { + static let getCurrentWeather = OrderedDictionary(uniqueKeysWithValues: [ + ("type", "function") as (String, Any), + ( + "function", + OrderedDictionary(uniqueKeysWithValues: [ + ("name", "get_current_weather") as (String, Any), + ("description", "Get the current weather in a given location") as (String, Any), + ( + "parameters", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "object") as (String, Any), + ( + "properties", + OrderedDictionary(uniqueKeysWithValues: [ + ( + "location", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ("description", "The city and state, e.g. San Francisco, CA") + as (String, Any), + ]) + ) as (String, Any), + ( + "unit", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ("enum", ["celsius", "fahrenheit"]) as (String, Any), + ]) + ) as (String, Any), + ]) + ) as (String, Any), + ("required", ["location"]) as (String, Any), + ]) + ) as (String, Any), + ]) + ) as (String, Any), + ]) +} diff --git a/Tests/Templates/ToolUseTests.swift b/Tests/Templates/ToolUseTests.swift new file mode 100644 index 0000000..6b462b0 --- /dev/null +++ b/Tests/Templates/ToolUseTests.swift @@ -0,0 +1,685 @@ +// +// VisionTests.swift +// Jinja +// +// Created by Anthony DePasquale on 30.12.2024. +// + +import XCTest +import OrderedCollections + +/* + Recent models that don't support tool use: + - Gemma 2 + - Phi 3.5 + - Mistral NeMo + */ + +@testable import Jinja + +final class ToolUseTests: XCTestCase { + let messagesWithFunctionCalling: [[String: Any?]] = [ + [ + "role": "assistant", + "content": nil, + "tool_calls": [ + [ + "type": "function", + "function": [ + "name": "get_current_weather", + "arguments": "{\n \"location\": \"Hanoi\"\n}", + ], + ] + ], + ], + [ + "role": "user", + "content": "What's the weather like in Hanoi?", + ], + ] + + // Example adapted from https://huggingface.co/fireworks-ai/firefunction-v1 + let exampleFunctionSpec: [OrderedDictionary] = [ + OrderedDictionary(uniqueKeysWithValues: [ + ("name", "get_stock_price") as (String, Any), + ("description", "Get the current stock price") as (String, Any), + ( + "parameters", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "object") as (String, Any), + ( + "properties", + OrderedDictionary(uniqueKeysWithValues: [ + ( + "symbol", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ("description", "The stock symbol, e.g. AAPL, GOOG") as (String, Any), + ]) + ) + ]) + ) as (String, Any), + ("required", ["symbol"]) as (String, Any), + ]) + ) as (String, Any), + ]), + OrderedDictionary(uniqueKeysWithValues: [ + ("name", "check_word_anagram") as (String, Any), + ("description", "Check if two words are anagrams of each other") as (String, Any), + ( + "parameters", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "object") as (String, Any), + ( + "properties", + OrderedDictionary(uniqueKeysWithValues: [ + ( + "word1", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ("description", "The first word") as (String, Any), + ]) + ) as (String, Any), + ( + "word2", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ("description", "The second word") as (String, Any), + ]) + ) as (String, Any), + ]) + ) as (String, Any), + ("required", ["word1", "word2"]) as (String, Any), + ]) + ) as (String, Any), + ]), + ] + + lazy var messagesWithFunctionCallingAndSystemPrompt: [OrderedDictionary] = [ + OrderedDictionary(uniqueKeysWithValues: [ + ("role", "system") as (String, Any), + ("content", "You are a helpful assistant with access to functions. Use them if required.") as (String, Any), + ]), + OrderedDictionary(uniqueKeysWithValues: [ + ("role", "functions") as (String, Any), + ("content", exampleFunctionSpec) as (String, Any), + ]), + OrderedDictionary(uniqueKeysWithValues: [ + ("role", "user") as (String, Any), + ("content", "Hi, can you tell me the current stock price of AAPL?") as (String, Any), + ]), + ] + + let exampleToolJSONSchemas: OrderedDictionary> = OrderedDictionary( + uniqueKeysWithValues: [ + ( + "get_current_weather", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "function") as (String, Any), + ( + "function", + OrderedDictionary(uniqueKeysWithValues: [ + ("name", "get_current_weather") as (String, Any), + ("description", "Get the current weather in a given location") as (String, Any), + ( + "parameters", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "object") as (String, Any), + ( + "properties", + OrderedDictionary(uniqueKeysWithValues: [ + ( + "location", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ("description", "The city and state, e.g. San Francisco, CA") + as (String, Any), + ]) + ) as (String, Any), + ( + "unit", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ("enum", ["celsius", "fahrenheit"]) as (String, Any), + ]) + ) as (String, Any), + ]) + ) as (String, Any), + ("required", ["location"]) as (String, Any), + ]) + ) as (String, Any), + ]) + ) as (String, Any), + ]) + ), + ( + "get_current_temperature_v1", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "function") as (String, Any), + ( + "function", + OrderedDictionary(uniqueKeysWithValues: [ + ("name", "get_current_temperature") as (String, Any), + ("description", "Get the current temperature at a location.") as (String, Any), + ( + "parameters", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "object") as (String, Any), + ( + "properties", + OrderedDictionary(uniqueKeysWithValues: [ + ( + "location", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ( + "description", + "The location to get the temperature for, in the format \"City, Country\"" + ) as (String, Any), + ]) + ) as (String, Any) + ]) + ) as (String, Any), + ("required", ["location"]) as (String, Any), + ]) + ) as (String, Any), + ( + "return", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "number") as (String, Any), + ( + "description", + "The current temperature at the specified location in the specified units, as a float." + ) as (String, Any), + ]) + ) as (String, Any), + ]) + ) as (String, Any), + ]) + ), + ( + "get_current_temperature_v2", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "function") as (String, Any), + ( + "function", + OrderedDictionary(uniqueKeysWithValues: [ + ("name", "get_current_temperature") as (String, Any), + ("description", "Get the current temperature at a location.") as (String, Any), + ( + "parameters", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "object") as (String, Any), + ( + "properties", + OrderedDictionary(uniqueKeysWithValues: [ + ( + "location", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ( + "description", + "The location to get the temperature for, in the format \"City, Country\"" + ) as (String, Any), + ]) + ) as (String, Any), + ( + "unit", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ("enum", ["celsius", "fahrenheit"]) as (String, Any), + ("description", "The unit to return the temperature in.") + as (String, Any), + ]) + ) as (String, Any), + ]) + ) as (String, Any), + ("required", ["location", "unit"]) as (String, Any), + ]) + ) as (String, Any), + ( + "return", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "number") as (String, Any), + ( + "description", + "The current temperature at the specified location in the specified units, as a float." + ) as (String, Any), + ]) + ) as (String, Any), + ]) + ) as (String, Any), + ]) + ), + ( + "get_current_wind_speed", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "function") as (String, Any), + ( + "function", + OrderedDictionary(uniqueKeysWithValues: [ + ("name", "get_current_wind_speed") as (String, Any), + ("description", "Get the current wind speed in km/h at a given location.") as (String, Any), + ( + "parameters", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "object") as (String, Any), + ( + "properties", + OrderedDictionary(uniqueKeysWithValues: [ + ( + "location", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ( + "description", + "The location to get the temperature for, in the format \"City, Country\"" + ) as (String, Any), + ]) + ) as (String, Any) + ]) + ) as (String, Any), + ("required", ["location"]) as (String, Any), + ]) + ) as (String, Any), + ( + "return", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "number") as (String, Any), + ("description", "The current wind speed at the given location in km/h, as a float.") + as (String, Any), + ]) + ) as (String, Any), + ]) + ) as (String, Any), + ]) + ), + ]) + + lazy var exampleListOfTools: [OrderedDictionary] = [ + exampleToolJSONSchemas["get_current_temperature_v2"]!, + exampleToolJSONSchemas["get_current_wind_speed"]!, + ] + + func testMeetKaiFunctionaryMediumV2_2() throws { + let chatTemplate = """ + {#v2.2#}\n{% for message in messages %}\n{% if message['role'] == 'user' or message['role'] == 'system' %}\n{{ '<|from|>' + message['role'] + '\n<|recipient|>all\n<|content|>' + message['content'] + '\n' }}{% elif message['role'] == 'tool' %}\n{{ '<|from|>' + message['name'] + '\n<|recipient|>all\n<|content|>' + message['content'] + '\n' }}{% else %}\n{% set contain_content='no'%}\n{% if message['content'] is not none %}\n{{ '<|from|>assistant\n<|recipient|>all\n<|content|>' + message['content'] }}{% set contain_content='yes'%}\n{% endif %}\n{% if 'tool_calls' in message and message['tool_calls'] is not none %}\n{% for tool_call in message['tool_calls'] %}\n{% set prompt='<|from|>assistant\n<|recipient|>' + tool_call['function']['name'] + '\n<|content|>' + tool_call['function']['arguments'] %}\n{% if loop.index == 1 and contain_content == "no" %}\n{{ prompt }}{% else %}\n{{ '\n' + prompt}}{% endif %}\n{% endfor %}\n{% endif %}\n{{ '<|stop|>\n' }}{% endif %}\n{% endfor %}\n{% if add_generation_prompt %}{{ '<|from|>assistant\n<|recipient|>' }}{% endif %} + """ + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messagesWithFunctionCalling, + "bos_token": "", + "eos_token": "", + "add_generation_prompt": false, + ]) + let target = + """ + <|from|>assistant\n<|recipient|>get_current_weather\n<|content|>{\n "location": "Hanoi"\n}<|stop|>\n<|from|>user\n<|recipient|>all\n<|content|>What's the weather like in Hanoi?\n + """ + XCTAssertEqual(result, target) + } + + func testFireworksAIFireFunctionV1() throws { + let chatTemplate = """ + {%- set message_roles = ['SYSTEM', 'FUNCTIONS', 'USER', 'ASSISTANT', 'TOOL'] -%}\n{%- set ns = namespace(seen_non_system=false, messages=messages, content='', functions=[]) -%}\n{{ bos_token }}\n{#- Basic consistency checks -#}\n{%- if not ns.messages -%}\n {{ raise_exception('No messages') }}\n{%- endif -%}\n{%- if ns.messages[0]['role'] | upper != 'SYSTEM' -%}\n {%- set ns.messages = [{'role': 'SYSTEM', 'content': 'You are a helpful assistant with access to functions. Use them if required.'}] + ns.messages -%}\n{%- endif -%}\n{%- if ns.messages | length < 2 or ns.messages[0]['role'] | upper != 'SYSTEM' or ns.messages[1]['role'] | upper != 'FUNCTIONS' -%}\n {{ raise_exception('Expected either "functions" or ["system", "functions"] as the first messages') }}\n{%- endif -%}\n{%- for message in ns.messages -%}\n {%- set role = message['role'] | upper -%}\n {#- Validation -#}\n {%- if role not in message_roles -%}\n {{ raise_exception('Invalid role ' + message['role'] + '. Only ' + message_roles + ' are supported.') }}\n {%- endif -%}\n {%- set ns.content = message['content'] if message.get('content') else '' -%}\n {#- Move tool calls inside the content -#}\n {%- if 'tool_calls' in message -%}\n {%- for call in message['tool_calls'] -%}\n {%- set ns.content = ns.content + '{"name": "' + call['function']['name'] + '", "arguments": ' + call['function']['arguments'] + '}' -%}\n {%- endfor -%}\n {%- endif -%}\n {%- if role == 'ASSISTANT' and '' not in ns.content -%}\n {%- set ns.content = '' + ns.content -%}\n {%- endif -%}\n {%- if role == 'ASSISTANT' -%}\n {%- set ns.content = ns.content + eos_token -%}\n {%- endif -%}\n {{ role }}: {{ ns.content }}{{ '\\n\\n' }}\n{%- endfor -%}\nASSISTANT:{{ ' ' }}\n + """ + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": messagesWithFunctionCallingAndSystemPrompt, + "bos_token": "", + "eos_token": "", + "add_generation_prompt": false, + ]) + let target = """ + SYSTEM: You are a helpful assistant with access to functions. Use them if required. + + FUNCTIONS: [{"name": "get_stock_price", "description": "Get the current stock price", "parameters": {"type": "object", "properties": {"symbol": {"type": "string", "description": "The stock symbol, e.g. AAPL, GOOG"}}, "required": ["symbol"]}}, {"name": "check_word_anagram", "description": "Check if two words are anagrams of each other", "parameters": {"type": "object", "properties": {"word1": {"type": "string", "description": "The first word"}, "word2": {"type": "string", "description": "The second word"}}, "required": ["word1", "word2"]}}] + + USER: Hi, can you tell me the current stock price of AAPL? + + ASSISTANT: + """ + XCTAssertEqual(result, target) + } + + // Fails because tools are omitted in the output, and the result is indented. + // func testMistral7BInstructV0_3JSONSchema() throws { + // let chatTemplate = + // "{{- bos_token }}\n{%- set user_messages = messages | selectattr('role', 'equalto', 'user') | list %}\n{%- for message in messages %}\n {%- if message['role'] == 'user' %}\n {%- if tools and (message == user_messages[-1]) %}\n {{- ' [AVAILABLE_TOOLS] [' }}\n {%- for tool in tools %}\n\t\t{%- set tool = tool.function %}\n\t\t{{- '{\"type\": \"function\", \"function\": {' }}\n\t\t{%- for key, val in tool|items if key != \"return\" %}\n\t\t {%- if val is string %}\n\t\t\t{{- '\"' + key + '\": \"' + val + '\"' }}\n\t\t {%- else %}\n\t\t\t{{- '\"' + key + '\": ' + val|tojson }}\n\t\t {%- endif %}\n\t\t {%- if not loop.last %}\n\t\t\t{{- \", \" }}\n\t\t {%- endif %}\n\t\t{%- endfor %}\n\t\t{{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- ' [/AVAILABLE_TOOLS]' }}\n {%- endif %}\n {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n {%- elif message['role'] == 'assistant' %}\n {%- if message.tool_calls is defined and message.tool_calls|length > 0 %}\n {{- ' [TOOL_CALLS] [' }}\n {%- for tool_call in message.tool_calls %}\n {{- {\"name\": tool_call.function.name, \"arguments\": tool_call.function.arguments, \"id\": tool_call.id}|tojson }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- '] ' }}\n {{- eos_token }}\n \t{%- elif message.content is defined %}\n\t {{- ' ' + message.content + ' ' + eos_token}}\n {%- endif %}\n {%- elif message['role'] == 'tool' %}\n {{- ' [TOOL_RESULTS] ' }}\n {{- '{\"call_id\": \"' + message.tool_call_id + '\", \"content\": ' + message.content|string + '}' }}\n {{- ' [/TOOL_RESULTS] ' }}\n {%- endif %}\n{%- endfor %}\n" + // let template = try Template(chatTemplate) + // + // let result = try template.render([ + // "messages": [ + // [ + // "role": "system", + // "content": + // "You are a bot that responds to weather queries. You should reply with the unit used in the queried location.", + // ], + // ["role": "user", "content": "Hey, what's the temperature in Paris right now?"], + // [ + // "role": "assistant", + // "tool_calls": [ + // [ + // "id": "abcdef123", + // "type": "function", + // "function": [ + // "name": "get_current_temperature", + // "arguments": ["location": "Paris, France", "unit": "celsius"], + // ], + // ] + // ], + // ], + // ["role": "tool", "tool_call_id": "abcdef123", "name": "get_current_temperature", "content": "22.0"], + // ], + // "tools": exampleListOfTools, + // // "tools_json": "", // TODO: Figure out how to convert the array of OrderedDictionaries to JSON + // "bos_token": "", + // "eos_token": "", + // ]) + // let target = """ + // [AVAILABLE_TOOLS] [{"type": "function", "function": {"name": "get_current_temperature", "description": "Get the current temperature at a location.", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The location to get the temperature for, in the format \\"City, Country\\""}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "The unit to return the temperature in."}}, "required": ["location", "unit"]}}}, {"type": "function", "function": {"name": "get_current_wind_speed", "description": "Get the current wind speed in km/h at a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The location to get the temperature for, in the format \\"City, Country\\""}}, "required": ["location"]}}}] [/AVAILABLE_TOOLS] [INST] Hey, what\'s the temperature in Paris right now? [/INST] [TOOL_CALLS] [{"name": "get_current_temperature", "arguments": {"location": "Paris, France", "unit": "celsius"}, "id": "abcdef123"}] [TOOL_RESULTS] {"call_id": "abcdef123", "content": 22.0} [/TOOL_RESULTS] + // """ + // + // XCTAssertEqual(result, target) + // } + + // Previously failed because tools are omitted in the output, now fails because of error with `map`: runtime("map filter requires either an attribute name or a function") + // func testCISCaiMistral7BInstructV0_3SOTAGGUF() throws { + // let chatTemplate = """ + // {{ bos_token }}{% set ns = namespace(lastuser=-1, system=false, functions=false) %}{% if tools %}{% for message in messages %}{% if message['role'] == 'user' %}{% set ns.lastuser = loop.index0 %}{% elif message['role'] == 'system' %}{% set ns.system = message['content'] %}{% endif %}{% endfor %}{% set ns.functions = tools|selectattr('type','eq','function')|map(attribute='function')|list|tojson %}{% endif %}{% for message in messages %}{% if message['role'] == 'user' %}{% if loop.index0 == ns.lastuser and ns.functions %}{{ '[AVAILABLE_TOOLS] ' }}{{ ns.functions }}{{ '[/AVAILABLE_TOOLS]' }}{% endif %}{{ '[INST] ' }}{% if loop.index0 == ns.lastuser and ns.system %}{{ ns.system + ' ' }}{% endif %}{{ message['content'] }}{{ '[/INST]' }}{% elif message['role'] == 'tool' %}{{ '[TOOL_RESULTS] ' }}{{ dict(call_id=message['tool_call_id'], content=message['content'])|tojson }}{{ '[/TOOL_RESULTS]' }}{% elif message['role'] == 'assistant' %}{% if message['tool_calls'] %}{{ '[TOOL_CALLS] [' }}{% for call in message['tool_calls'] %}{% if call['type'] == 'function' %}{{ dict(id=call['id'], name=call['function']['name'], arguments=call['function']['arguments'])|tojson }}{% endif %}{% if not loop.last %}{{ ', ' }}{% endif %}{% endfor %}{{ ']' }}{% else %}{{ message['content'] }}{% endif %}{{ eos_token }}{% endif %}{% endfor %} + // """ + // let template = try Template(chatTemplate) + // + // let result = try template.render([ + // "messages": [ + // [ + // "role": "user", + // "content": "What's the weather like in Oslo and Stockholm?", + // ] + // ], + // "tools": [exampleToolJSONSchemas["get_current_temperature_v2"]!], + // "bos_token": "", + // "eos_token": "", + // ]) + // let target = + // """ + // [AVAILABLE_TOOLS] [{"name": "get_current_weather", "description": "Get the current weather in a given location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}}, "required": ["location"]}}][/AVAILABLE_TOOLS][INST] What's the weather like in Oslo and Stockholm?[/INST] + // """ + // + // XCTAssertEqual(result, target) + // } + + func testNousResearchHermes2ProLlama38BJSONSchema() throws { + let chatTemplate = """ + {%- macro json_to_python_type(json_spec) %}\n{%- set basic_type_map = {\n "string": "str",\n "number": "float",\n "integer": "int",\n "boolean": "bool"\n} %}\n\n{%- if basic_type_map[json_spec.type] is defined %}\n {{- basic_type_map[json_spec.type] }}\n{%- elif json_spec.type == "array" %}\n {{- "list[" + json_to_python_type(json_spec|items) + "]"}}\n{%- elif json_spec.type == "object" %}\n {%- if json_spec.additionalProperties is defined %}\n {{- "dict[str, " + json_to_python_type(json_spec.additionalProperties) + ']'}}\n {%- else %}\n {{- "dict" }}\n {%- endif %}\n{%- elif json_spec.type is iterable %}\n {{- "Union[" }}\n {%- for t in json_spec.type %}\n {{- json_to_python_type({"type": t}) }}\n {%- if not loop.last %}\n {{- "," }} \n {%- endif %}\n {%- endfor %}\n {{- "]" }}\n{%- else %}\n {{- "Any" }}\n{%- endif %}\n{%- endmacro %}\n\n\n{{- bos_token }}\n{{- "You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: " }}\n{%- for tool in tools %}\n {%- if tool.function is defined %}\n {%- set tool = tool.function %}\n {%- endif %}\n {{- '{"type": "function", "function": ' }}\n {{- '{"name": ' + tool.name + '", ' }}\n {{- '"description": "' + tool.name + '(' }}\n {%- for param_name, param_fields in tool.parameters.properties|items %}\n {{- param_name + ": " + json_to_python_type(param_fields) }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- endif %}\n {%- endfor %}\n {{- ")" }}\n {%- if tool.return is defined %}\n {{- " -> " + json_to_python_type(tool.return) }}\n {%- endif %}\n {{- " - " + tool.description + "\\n\\n" }}\n {%- for param_name, param_fields in tool.parameters.properties|items %}\n {%- if loop.first %}\n {{- " Args:\\n" }}\n {%- endif %}\n {{- " " + param_name + "(" + json_to_python_type(param_fields) + "): " + param_fields.description|trim }}\n {%- endfor %}\n {%- if tool.return is defined and tool.return.description is defined %}\n {{- "\\n Returns:\\n " + tool.return.description }}\n {%- endif %}\n {{- '"' }}\n {{- ', "parameters": ' }}\n {%- if tool.parameters.properties | length == 0 %}\n {{- "{}" }}\n {%- else %}\n {{- tool.parameters | tojson}}\n {%- endif %}\n {{- "}" }}\n {%- if not loop.last %}\n {{- "\\n" }}\n {%- endif %}\n{%- endfor %}\n{{- " " }}\n{{- 'Use the following pydantic model json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"}\n' }}\n{{- "For each function call return a json object with function name and arguments within XML tags as follows:\n" }}\n{{- "\n" }}\n{{- '{"arguments": , "name": }\n' }}\n{{- '<|im_end|>' }}\n{%- for message in messages %}\n {%- if message.role == "user" or message.role == "system" or (message.role == "assistant" and message.tool_calls is not defined) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == "assistant" %}\n {{- '<|im_start|>' + message.role + '\\n\\n' }}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '{ ' }}\n {%- if tool_call.arguments is defined %}\n {{- '"arguments": ' }}\n {{- tool_call.arguments|tojson }}\n {{- ', '}}\n {%- endif %}\n {{- '"name": "' }}\n {{- tool_call.name }}\n {{- '"}' }}\n {{- '\\n ' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == "tool" %}\n {%- if not message.name is defined %}\n {{- raise_exception("Tool response dicts require a 'name' key indicating the name of the called function!") }}\n {%- endif %}\n {{- '<|im_start|>' + message.role + '\\n\\n' }}\n {{- '{"name": "' }}\n {{- message.name }}\n {{- '", "content": ' }}\n {{- message.content|tojson + '}' }}\n {{- '\\n <|im_end|>\\n' }} \n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n + """ + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": [ + OrderedDictionary(uniqueKeysWithValues: [ + ("role", "user") as (String, Any), + ("content", "Fetch the stock fundamentals data for Tesla (TSLA)") as (String, Any), + ]) + ], + "tools": [ + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "function") as (String, Any), + ( + "function", + OrderedDictionary(uniqueKeysWithValues: [ + ("name", "get_stock_fundamentals") as (String, Any), + ("description", "Get fundamental data for a given stock symbol using yfinance API.") + as (String, Any), + ( + "parameters", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "object") as (String, Any), + ( + "properties", + OrderedDictionary(uniqueKeysWithValues: [ + ( + "symbol", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string") as (String, Any), + ("description", "The stock symbol.") as (String, Any), + ]) + ) as (String, Any) + ]) + ) as (String, Any), + ("required", ["symbol"]) as (String, Any), + ]) + ) as (String, Any), + ( + "return", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "object") as (String, Any), + ( + "description", + """ + A dictionary containing fundamental data. + + Keys: + - 'symbol': The stock symbol. + - 'company_name': The long name of the company. + - 'sector': The sector to which the company belongs. + - 'industry': The industry to which the company belongs. + - 'market_cap': The market capitalization of the company. + - 'pe_ratio': The forward price-to-earnings ratio. + - 'pb_ratio': The price-to-book ratio. + - 'dividend_yield': The dividend yield. + - 'eps': The trailing earnings per share. + - 'beta': The beta value of the stock. + - '52_week_high': The 52-week high price of the stock. + - '52_week_low': The 52-week low price of the stock. + """ + ) as (String, Any), + ]) + ) as (String, Any), + ]) + ) as (String, Any), + ]) + ], + "bos_token": "<|begin_of_text|>", + "eos_token": "<|im_end|>", + "add_generation_prompt": true, + ]) + let target = """ + <|begin_of_text|>You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: {"type": "function", "function": {"name": get_stock_fundamentals", "description": "get_stock_fundamentals(symbol: str) -> dict - Get fundamental data for a given stock symbol using yfinance API.\n\n Args:\n symbol(str): The stock symbol.\n Returns:\n A dictionary containing fundamental data.\n\nKeys:\n - 'symbol': The stock symbol.\n - 'company_name': The long name of the company.\n - 'sector': The sector to which the company belongs.\n - 'industry': The industry to which the company belongs.\n - 'market_cap': The market capitalization of the company.\n - 'pe_ratio': The forward price-to-earnings ratio.\n - 'pb_ratio': The price-to-book ratio.\n - 'dividend_yield': The dividend yield.\n - 'eps': The trailing earnings per share.\n - 'beta': The beta value of the stock.\n - '52_week_high': The 52-week high price of the stock.\n - '52_week_low': The 52-week low price of the stock.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string", "description": "The stock symbol."}}, "required": ["symbol"]}} Use the following pydantic model json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"}\nFor each function call return a json object with function name and arguments within XML tags as follows:\n\n{"arguments": , "name": }\n<|im_end|><|im_start|>user\nFetch the stock fundamentals data for Tesla (TSLA)<|im_end|>\n<|im_start|>assistant\n + """ + XCTAssertEqual(result, target) + } + + // func testMetaLlamaLlama3_18BInstruct() throws { + // let chatTemplate = """ + // {{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- set date_string = "26 Jul 2024" %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = "" %}\n{%- endif %}\n\n{#- System message + builtin tools #}\n{{- "<|start_header_id|>system<|end_header_id|>\\n\\n" }}\n{%- if builtin_tools is defined or tools is not none %}\n {{- "Environment: ipython\\n" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n {{- "Tools: " + builtin_tools | reject('equalto', 'code_interpreter') | join(", ") + "\\n\\n"}}\n{%- endif %}\n{{- "Cutting Knowledge Date: December 2023\\n" }}\n{{- "Today Date: " + date_string + "\\n\\n" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}\n {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- "<|eot_id|>" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- "Given the following functions, please respond with a JSON for a function call " }}\n {{- "with its proper arguments that best answers the given prompt.\\n\\n" }}\n {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {{- first_user_message + "<|eot_id|>"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception("This model only supports single tool-calls at once!") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- "<|python_tag|>" + tool_call.name + ".call(" }}\n {%- for arg_name, arg_val in tool_call.arguments | items %}\n {{- arg_name + '="' + arg_val + '"' }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- endif %}\n {%- endfor %}\n {{- ")" }}\n {%- else %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{"name": "' + tool_call.name + '", ' }}\n {{- '"parameters": ' }}\n {{- tool_call.arguments | tojson }}\n {{- "}" }}\n {%- endif %}\n {%- if builtin_tools is defined %}\n {#- This means we're in ipython mode #}\n {{- "<|eom_id|>" }}\n {%- else %}\n {{- "<|eot_id|>" }}\n {%- endif %}\n {%- elif message.role == "tool" or message.role == "ipython" %}\n {{- "<|start_header_id|>ipython<|end_header_id|>\\n\\n" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- "<|eot_id|>" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n + // """ + // let template = try Template(chatTemplate) + // let result = try template.render([ + // "messages": [ + // ["role": "system", "content": "You are a bot that responds to weather queries."], + // ["role": "user", "content": "Hey, what's the temperature in Paris right now?"], + // ], + // "tools": [exampleToolJSONSchemas["get_current_temperature_v1"]!], + // "bos_token": "<|begin_of_text|>", + // "eos_token": "<|im_end|>", + // "add_generation_prompt": true, + // ]) + // let target = """ + // <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYou are a bot that responds to weather queries.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {"name": function name, "parameters": dictionary of argument name and its value}.Do not use variables.\n\n{\n "type": "function",\n "function": {\n "name": "get_current_temperature",\n "description": "Get the current temperature at a location.",\n "parameters": {\n "type": "object",\n "properties": {\n "location": {\n "type": "string",\n "description": "The location to get the temperature for, in the format \\"City, Country\\""\n }\n },\n "required": [\n "location"\n ]\n },\n "return": {\n "type": "number",\n "description": "The current temperature at the specified location in the specified units, as a float."\n }\n }\n}\n\nHey, what's the temperature in Paris right now?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n + // """ + // XCTAssertEqual(result, target) + // } + + // + + func testLlama3_1() throws { + let chatTemplate = ChatTemplate.llama3_1 + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": Messages.weatherQuery, + "tools": [ToolSpec.getCurrentWeather], + "bos_token": "<|begin_of_text|>", + // "eos_token": "<|im_end|>", + "add_generation_prompt": true, + ]) + let target = """ + <|begin_of_text|><|start_header_id|>system<|end_header_id|> + + Environment: ipython + Cutting Knowledge Date: December 2023 + Today Date: 26 Jul 2024 + + <|eot_id|><|start_header_id|>user<|end_header_id|> + + Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. + + Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.Do not use variables. + + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": [ + "celsius", + "fahrenheit" + ] + } + }, + "required": [ + "location" + ] + } + } + } + + What is the weather in Paris today?<|eot_id|><|start_header_id|>assistant<|end_header_id|> + + + """ + XCTAssertEqual(result, target) + } + + func testLlama3_2() throws { + let chatTemplate = ChatTemplate.llama3_2 + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": Messages.weatherQuery, + "tools": [ToolSpec.getCurrentWeather], + "bos_token": "<|begin_of_text|>", + // "eos_token": "<|im_end|>", + "add_generation_prompt": true, + ]) + let target = """ + <|begin_of_text|><|start_header_id|>system<|end_header_id|> + + Environment: ipython + Cutting Knowledge Date: December 2023 + Today Date: 26 Jul 2024 + + <|eot_id|><|start_header_id|>user<|end_header_id|> + + Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. + + Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.Do not use variables. + + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": [ + "celsius", + "fahrenheit" + ] + } + }, + "required": [ + "location" + ] + } + } + } + + What is the weather in Paris today?<|eot_id|><|start_header_id|>assistant<|end_header_id|> + + + """ + XCTAssertEqual(result, target) + } + + func testQwen2_5() throws { + let chatTemplate = ChatTemplate.qwen2_5 + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": Messages.weatherQuery, + "tools": [ToolSpec.getCurrentWeather], + "bos_token": "<|begin_of_text|>", + // "eos_token": "<|im_end|>", + "add_generation_prompt": true, + ]) + let target = """ + <|im_start|>system + You are Qwen, created by Alibaba Cloud. You are a helpful assistant. + + # Tools + + You may call one or more functions to assist with the user query. + + You are provided with function signatures within XML tags: + + {"type": "function", "function": {"name": "get_current_weather", "description": "Get the current weather in a given location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}}, "required": ["location"]}}} + + + For each function call, return a json object with function name and arguments within XML tags: + + {"name": , "arguments": } + <|im_end|> + <|im_start|>user + What is the weather in Paris today?<|im_end|> + <|im_start|>assistant + + """ + XCTAssertEqual(result, target) + } + + func testMistral7b() throws { + let chatTemplate = ChatTemplate.mistral7b + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": Messages.weatherQuery, + "tools": [ToolSpec.getCurrentWeather], + "bos_token": "<|begin_of_text|>", + // "eos_token": "<|im_end|>", + "add_generation_prompt": true, + ]) + let target = """ + <|begin_of_text|>[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_current_weather", "description": "Get the current weather in a given location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}}, "required": ["location"]}}}][/AVAILABLE_TOOLS][INST]What is the weather in Paris today?[/INST] + """ + XCTAssertEqual(result, target) + } +} + +extension Data { + var string: String? { + return String(data: self, encoding: .utf8) + } +} diff --git a/Tests/Templates/VisionTests.swift b/Tests/Templates/VisionTests.swift new file mode 100644 index 0000000..76593ce --- /dev/null +++ b/Tests/Templates/VisionTests.swift @@ -0,0 +1,257 @@ +// +// VisionTests.swift +// Jinja +// +// Created by Anthony DePasquale on 31.12.2024. +// + +import XCTest +import OrderedCollections + +@testable import Jinja + +final class VisionTests: XCTestCase { + let llama3_2visionChatTemplate = + "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- Find out if there are any images #}\n{% set image_ns = namespace(has_images=false) %} \n{%- for message in messages %}\n {%- for content in message['content'] %}\n {%- if content['type'] == 'image' %}\n {%- set image_ns.has_images = true %}\n {%- endif %}\n {%- endfor %}\n{%- endfor %}\n\n{#- Error out if there are images and system message #}\n{%- if image_ns.has_images and not system_message == \"\" %}\n {{- raise_exception(\"Prompting with images is incompatible with system messages.\") }}\n{%- endif %}\n\n{#- System message if there are no images #}\n{%- if not image_ns.has_images %}\n {{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n {%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n {%- endif %}\n {{- \"Cutting Knowledge Date: December 2023\\n\" }}\n {{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n {%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {%- endif %}\n {{- system_message }}\n {{- \"<|eot_id|>\" }}\n{%- endif %}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n' }}\n {%- if message['content'] is string %}\n {{- message['content'] }}\n {%- else %}\n {%- for content in message['content'] %}\n {%- if content['type'] == 'image' %}\n {{- '<|image|>' }}\n {%- elif content['type'] == 'text' %}\n {{- content['text'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {{- \"<|eot_id|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n" + let qwen2VLChatTemplate = + "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}" + + func testLlama3_2_11BVisionInstructTextChatOnly() throws { + let template = try Template(llama3_2visionChatTemplate) + let result = try template.render([ + "messages": [ + [ + "role": "user", + "content": [ + [ + "type": "text", + "text": "Hello, how are you?", + ] as [String: Any] + ] as [[String: Any]], + ] as [String: Any], + [ + "role": "assistant", + "content": [ + [ + "type": "text", + "text": "I'm doing great. How can I help you today?", + ] as [String: Any] + ] as [[String: Any]], + ] as [String: Any], + [ + "role": "user", + "content": [ + [ + "type": "text", + "text": "I'd like to show off how chat templating works!", + ] as [String: Any] + ] as [[String: Any]], + ] as [String: Any], + ] as [[String: Any]] as Any, + "bos_token": "" as Any, + "date_string": "26 Jul 2024" as Any, + "tools_in_user_message": true as Any, + "system_message": "You are a helpful assistant." as Any, + "add_generation_prompt": true as Any, + ]) + let target = + "\n<|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello, how are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI'm doing great. How can I help you today?<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nI'd like to show off how chat templating works!<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + XCTAssertEqual(result, target) + } + + func testLlama3_2_11BVisionInstructWithImages() throws { + let template = try Template(llama3_2visionChatTemplate) + let result = try template.render([ + "messages": [ + [ + "role": "user", + "content": [ + [ + "type": "text", + "text": "What's in this image?", + ] as [String: Any], + [ + "type": "image", + "image": "base64_encoded_image_data", + ] as [String: Any], + ] as [[String: Any]], + ] as [String: Any] + ] as [[String: Any]], + "bos_token": "" as Any, + "add_generation_prompt": true as Any, + ]) + let target = + "\n<|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat's in this image?<|image|><|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + XCTAssertEqual(result, target) + } + + func testQwen2VLWithImages() throws { + let template = try Template(qwen2VLChatTemplate) + let result = try template.render([ + "messages": [ + [ + "role": "user", + "content": [ + [ + "type": "text", + "text": "What's in this image?", + ] as [String: String], + [ + "type": "image", + "image_url": "example.jpg", + ] as [String: String], + ] as [[String: String]], + ] as [String: Any] + ] as [[String: Any]], + "add_generation_prompt": true, + "add_vision_id": true, + ]) + let target = """ + <|im_start|>system + You are a helpful assistant.<|im_end|> + <|im_start|>user + What's in this image?Picture 1: <|vision_start|><|image_pad|><|vision_end|><|im_end|> + <|im_start|>assistant + + """ + XCTAssertEqual(result, target) + } + + func testQwen2VLWithVideo() throws { + let template = try Template(qwen2VLChatTemplate) + let result = try template.render([ + "messages": [ + [ + "role": "user", + "content": [ + [ + "type": "text", + "text": "What's happening in this video?", + ] as [String: String], + [ + "type": "video", + "video_url": "example.mp4", + ] as [String: String], + ] as [[String: String]], + ] as [String: Any] + ] as [[String: Any]], + "add_generation_prompt": true, + "add_vision_id": true, + ]) + let target = """ + <|im_start|>system + You are a helpful assistant.<|im_end|> + <|im_start|>user + What's happening in this video?Video 1: <|vision_start|><|video_pad|><|vision_end|><|im_end|> + <|im_start|>assistant + + """ + XCTAssertEqual(result, target) + } + + func testLlama3_2_11BVisionInstructWithTools() throws { + let template = try Template(llama3_2visionChatTemplate) + + let tools: [OrderedDictionary] = [ + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "function" as Any), + ( + "function", + OrderedDictionary(uniqueKeysWithValues: [ + ("name", "get_current_weather" as Any), + ("description", "Get the current weather in a given location" as Any), + ( + "parameters", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "object" as Any), + ( + "properties", + OrderedDictionary(uniqueKeysWithValues: [ + ( + "location", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string" as Any), + ("description", "The city and state, e.g. San Francisco, CA" as Any), + ]) as Any + ), + ( + "unit", + OrderedDictionary(uniqueKeysWithValues: [ + ("type", "string" as Any), + ("enum", ["celsius", "fahrenheit"] as Any), + ]) as Any + ), + ]) as Any + ), + ("required", ["location"] as Any), + ]) as Any + ), + ]) as Any + ), + ]) + ] + + let result = try template.render([ + "messages": [ + [ + "role": "system", + "content": "You are a helpful assistant.", + ], + [ + "role": "user", + "content": "What's the weather like in San Francisco?", + ] as [String: Any], + ] as [[String: Any]] as Any, + "bos_token": "" as Any, + "add_generation_prompt": true as Any, + "tools": tools as Any, + "tools_in_user_message": true as Any, + ]) + let target = """ + + <|start_header_id|>system<|end_header_id|> + + Environment: ipython + Cutting Knowledge Date: December 2023 + Today Date: 26 Jul 2024 + + You are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|> + + Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. + + Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.Do not use variables. + + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": [ + "celsius", + "fahrenheit" + ] + } + }, + "required": [ + "location" + ] + } + } + } + + What's the weather like in San Francisco?<|eot_id|><|start_header_id|>assistant<|end_header_id|> + + + """ + XCTAssertEqual(result, target) + } +} diff --git a/Tests/TestTests.swift b/Tests/TestTests.swift new file mode 100644 index 0000000..212598f --- /dev/null +++ b/Tests/TestTests.swift @@ -0,0 +1,150 @@ +// +// TestTests.swift +// Jinja +// +// Created by Anthony DePasquale on 07.01.2025. +// + +// Adapted from https://github.com/pallets/jinja/blob/main/tests/test_tests.py + +import XCTest +@testable import Jinja + +final class TestTests: XCTestCase { + func testTests() throws { + // Helper function to run tests + func runTest( + testName: String, + input: Any, + args: [Any?] = [], + expected: Bool, + file: StaticString = #file, + line: UInt = #line + ) throws { + let env = Environment() + + // Convert input to RuntimeValue + guard let input = try? env.convertToRuntimeValues(input: input) else { + XCTFail( + "Failed to convert input \(input) to RuntimeValue in test for \(testName)", + file: file, + line: line + ) + return + } + + // Convert args to RuntimeValues + let runtimeArgs = try args.map { arg -> any RuntimeValue in + if let arg = arg { + return try env.convertToRuntimeValues(input: arg) + } + return UndefinedValue() + } + + // Get the test function from the environment + guard let test = env.tests[testName] else { + XCTFail("Test not found: \(testName)", file: file, line: line) + return + } + + // Call the test function based on number of arguments + let result: Bool + switch runtimeArgs.count { + case 0: + result = try test(input) + case 1: + result = try test(input, runtimeArgs[0]) + case 2: + result = try test(input, runtimeArgs[0], runtimeArgs[1]) + case 3: + result = try test(input, runtimeArgs[0], runtimeArgs[1], runtimeArgs[2]) + default: + throw JinjaError.runtime("Unsupported number of arguments for test: \(testName)") + } + + XCTAssertEqual(result, expected, "\(testName) test failed", file: file, line: line) + } + + // Test defined + try runTest(testName: "defined", input: UndefinedValue(), expected: false) + try runTest(testName: "defined", input: true, expected: true) + + // Test even/odd + try runTest(testName: "even", input: 1, expected: false) + try runTest(testName: "even", input: 2, expected: true) + try runTest(testName: "odd", input: 1, expected: true) + try runTest(testName: "odd", input: 2, expected: false) + + // Test lower/upper + try runTest(testName: "lower", input: "foo", expected: true) + try runTest(testName: "lower", input: "FOO", expected: false) + try runTest(testName: "upper", input: "FOO", expected: true) + try runTest(testName: "upper", input: "foo", expected: false) + + // Test type checks + try runTest(testName: "none", input: NullValue(), expected: true) + try runTest(testName: "none", input: false, expected: false) + try runTest(testName: "none", input: true, expected: false) + try runTest(testName: "none", input: 42, expected: false) + + try runTest(testName: "boolean", input: false, expected: true) + try runTest(testName: "boolean", input: true, expected: true) + try runTest(testName: "boolean", input: 0, expected: false) + try runTest(testName: "boolean", input: 1, expected: false) + + try runTest(testName: "false", input: false, expected: true) + try runTest(testName: "false", input: true, expected: false) + try runTest(testName: "true", input: true, expected: true) + try runTest(testName: "true", input: false, expected: false) + + try runTest(testName: "integer", input: 42, expected: true) + try runTest(testName: "integer", input: 3.14159, expected: false) + try runTest(testName: "float", input: 3.14159, expected: true) + try runTest(testName: "float", input: 42, expected: false) + + try runTest(testName: "string", input: "foo", expected: true) + try runTest(testName: "string", input: 42, expected: false) + + try runTest(testName: "sequence", input: [1, 2, 3], expected: true) + try runTest(testName: "sequence", input: "foo", expected: true) + try runTest(testName: "sequence", input: 42, expected: false) + + try runTest(testName: "mapping", input: ["foo": "bar"], expected: true) + try runTest(testName: "mapping", input: [1, 2, 3], expected: false) + + try runTest(testName: "number", input: 42, expected: true) + try runTest(testName: "number", input: 3.14159, expected: true) + try runTest(testName: "number", input: "foo", expected: false) + + // Test equalto/eq + try runTest(testName: "eq", input: 12, args: [12], expected: true) + try runTest(testName: "eq", input: 12, args: [0], expected: false) + try runTest(testName: "eq", input: "baz", args: ["baz"], expected: true) + try runTest(testName: "eq", input: "baz", args: ["zab"], expected: false) + + // Test comparison aliases + try runTest(testName: "ne", input: 2, args: [3], expected: true) + try runTest(testName: "ne", input: 2, args: [2], expected: false) + try runTest(testName: "lt", input: 2, args: [3], expected: true) + try runTest(testName: "lt", input: 2, args: [2], expected: false) + try runTest(testName: "le", input: 2, args: [2], expected: true) + try runTest(testName: "le", input: 2, args: [1], expected: false) + try runTest(testName: "gt", input: 2, args: [1], expected: true) + try runTest(testName: "gt", input: 2, args: [2], expected: false) + try runTest(testName: "ge", input: 2, args: [2], expected: true) + try runTest(testName: "ge", input: 2, args: [3], expected: false) + + // Test in + try runTest(testName: "in", input: "o", args: [["f", "o", "o"]], expected: true) + try runTest(testName: "in", input: "foo", args: [["foo"]], expected: true) + try runTest(testName: "in", input: "b", args: [["f", "o", "o"]], expected: false) + try runTest(testName: "in", input: 1, args: [[1, 2]], expected: true) + try runTest(testName: "in", input: 3, args: [[1, 2]], expected: false) + + // Test filter/test existence + try runTest(testName: "filter", input: "title", expected: true) + try runTest(testName: "filter", input: "bad-name", expected: false) + try runTest(testName: "test", input: "number", expected: true) + try runTest(testName: "test", input: "bad-name", expected: false) + } +}