From 77afc64739fe2028b78c20ff5f9da570040a746a Mon Sep 17 00:00:00 2001 From: Jacopo Mangiavacchi Date: Wed, 6 Nov 2019 19:57:17 -0800 Subject: [PATCH] Removed non list in the Feature enum --- Sources/SwiftTFRecords/Feature.swift | 84 +++++++++++++------- Sources/SwiftTFRecords/Record.swift | 36 +-------- Tests/SwiftTFRecordsTests/FeatureTests.swift | 14 ++-- 3 files changed, 65 insertions(+), 69 deletions(-) diff --git a/Sources/SwiftTFRecords/Feature.swift b/Sources/SwiftTFRecords/Feature.swift index fea96bb..984823d 100644 --- a/Sources/SwiftTFRecords/Feature.swift +++ b/Sources/SwiftTFRecords/Feature.swift @@ -8,9 +8,6 @@ import Foundation public enum Feature { - case Float(_ value: Float) - case Int(_ value: Int) - case Bytes(_ value: Data) case FloatArray(_ value: [Float]) case IntArray(_ value: [Int]) case BytesArray(_ value: [Data]) @@ -24,78 +21,107 @@ public enum Feature { public static func StringArray(_ value: [Swift.String]) -> Self { return Feature.BytesArray(value.map{ Data(Swift.String("\($0)").utf8) }) } - - public func toFloat() -> Float? { + + public static func Float(_ value: Swift.Float) -> Self { + return Feature.FloatArray([value]) + } + + public static func Int(_ value: Swift.Int) -> Self { + return Feature.IntArray([value]) + } + + public static func Bytes(_ value: Data) -> Self { + return Feature.BytesArray([value]) + } + + public func toFloatArray() -> [Float]? { switch self { - case .Float(let value): + case .FloatArray(let value): return value + default: return nil } } - public func toInt() -> Int? { + public func toIntArray() -> [Int]? { switch self { - case .Int(let value): + case .IntArray(let value): return value + default: return nil } } - public func toBytes() -> Data? { + public func toBytesArray() -> [Data]? { switch self { - case .Bytes(let value): + case .BytesArray(let value): return value + default: return nil } } - public func toFloatArray() -> [Float]? { + public func toString() -> Swift.String? { switch self { - case .FloatArray(let value): - return value + case .BytesArray(let value): + if value.count == 1, let string = Swift.String(bytes: value[0], encoding: .utf8) { + return string + } + return nil + default: return nil } } - - public func toIntArray() -> [Int]? { + + public func toStringArray() -> [Swift.String]? { switch self { - case .IntArray(let value): - return value + case .BytesArray(let value): + let stringArray = value.compactMap{ Swift.String(bytes: $0, encoding: .utf8) } + return stringArray.isEmpty ? nil : stringArray + default: return nil } } - public func toBytesArray() -> [Data]? { + public func toFloat() -> Float? { switch self { - case .BytesArray(let value): - return value + case .FloatArray(let value): + if value.count == 1 { + return value[0] + } + return nil + default: return nil } } - - public func toString() -> Swift.String? { + + public func toInt() -> Int? { switch self { - case .Bytes(let value): - if let string = Swift.String(bytes: value, encoding: .utf8) { - return string + case .IntArray(let value): + if value.count == 1 { + return value[0] } - return nil + default: return nil } } - - public func toStringArray() -> [Swift.String]? { + + public func toBytes() -> Data? { switch self { case .BytesArray(let value): - return value.compactMap{ Swift.String(bytes: $0, encoding: .utf8) } + if value.count == 1 { + return value[0] + } + return nil + default: return nil } diff --git a/Sources/SwiftTFRecords/Record.swift b/Sources/SwiftTFRecords/Record.swift index f49d6ab..0160428 100644 --- a/Sources/SwiftTFRecords/Record.swift +++ b/Sources/SwiftTFRecords/Record.swift @@ -17,21 +17,6 @@ public struct Record { var tfFeature = Tfrecords_Feature() switch feature { - case let .Float(value): - var list = Tfrecords_FloatList() - list.value = [value] - tfFeature.floatList = list - - case let .Int(value): - var list = Tfrecords_Int64List() - list.value = [Int64(value)] - tfFeature.int64List = list - - case let .Bytes(value): - var list = Tfrecords_BytesList() - list.value = [value] - tfFeature.bytesList = list - case let .FloatArray(value): var list = Tfrecords_FloatList() list.value = value @@ -66,32 +51,17 @@ public struct Record { for (name, feature) in example.features.feature { switch feature.kind { case let .floatList(list): - switch list.value.count { - case 0: - break - case 1: - features[name] = Feature.Float(list.value[0]) - default: + if !list.value.isEmpty { features[name] = Feature.FloatArray(list.value) } case let .int64List(list): - switch list.value.count { - case 0: - break - case 1: - features[name] = Feature.Int(Int(list.value[0])) - default: + if !list.value.isEmpty { features[name] = Feature.IntArray(list.value.map { Int($0) }) } case let .bytesList(list): - switch list.value.count { - case 0: - break - case 1: - features[name] = Feature.Bytes(list.value[0]) - default: + if !list.value.isEmpty { features[name] = Feature.BytesArray(list.value) } diff --git a/Tests/SwiftTFRecordsTests/FeatureTests.swift b/Tests/SwiftTFRecordsTests/FeatureTests.swift index e5b38f6..4984359 100644 --- a/Tests/SwiftTFRecordsTests/FeatureTests.swift +++ b/Tests/SwiftTFRecordsTests/FeatureTests.swift @@ -6,11 +6,11 @@ final class FeatureTests: XCTestCase { let feature: Feature = 12.34 XCTAssertEqual(feature.toFloat(), 12.34) + XCTAssertEqual(feature.toFloatArray(), [12.34]) XCTAssertNil(feature.toInt()) XCTAssertNil(feature.toBytes()) XCTAssertNil(feature.toString()) XCTAssertNil(feature.toIntArray()) - XCTAssertNil(feature.toFloatArray()) XCTAssertNil(feature.toBytesArray()) XCTAssertNil(feature.toStringArray()) } @@ -19,10 +19,10 @@ final class FeatureTests: XCTestCase { let feature: Feature = 17 XCTAssertEqual(feature.toInt(), 17) + XCTAssertEqual(feature.toIntArray(), [17]) XCTAssertNil(feature.toFloat()) XCTAssertNil(feature.toBytes()) XCTAssertNil(feature.toString()) - XCTAssertNil(feature.toIntArray()) XCTAssertNil(feature.toFloatArray()) XCTAssertNil(feature.toBytesArray()) XCTAssertNil(feature.toStringArray()) @@ -32,12 +32,12 @@ final class FeatureTests: XCTestCase { let feature: Feature = Feature.Bytes(Data([0, 202, 255, 44, 5])) XCTAssertEqual(feature.toBytes(), Data([0, 202, 255, 44, 5])) + XCTAssertEqual(feature.toBytesArray(), [Data([0, 202, 255, 44, 5])]) XCTAssertNil(feature.toInt()) XCTAssertNil(feature.toFloat()) XCTAssertNil(feature.toString()) XCTAssertNil(feature.toIntArray()) XCTAssertNil(feature.toFloatArray()) - XCTAssertNil(feature.toBytesArray()) XCTAssertNil(feature.toStringArray()) } @@ -45,26 +45,26 @@ final class FeatureTests: XCTestCase { let feature: Feature = "Jacopo 😃" XCTAssertEqual(feature.toString(), "Jacopo 😃") + XCTAssertEqual(feature.toStringArray(), ["Jacopo 😃"]) XCTAssertEqual(feature.toBytes(), Data([74, 97, 99, 111, 112, 111, 32, 240, 159, 152, 131])) + XCTAssertEqual(feature.toBytesArray(), [Data([74, 97, 99, 111, 112, 111, 32, 240, 159, 152, 131])]) XCTAssertNil(feature.toInt()) XCTAssertNil(feature.toFloat()) XCTAssertNil(feature.toIntArray()) XCTAssertNil(feature.toFloatArray()) - XCTAssertNil(feature.toBytesArray()) - XCTAssertNil(feature.toStringArray()) } func testString2() { let feature: Feature = Feature.String("Jacopo 😃") XCTAssertEqual(feature.toString(), "Jacopo 😃") + XCTAssertEqual(feature.toStringArray(), ["Jacopo 😃"]) XCTAssertEqual(feature.toBytes(), Data([74, 97, 99, 111, 112, 111, 32, 240, 159, 152, 131])) + XCTAssertEqual(feature.toBytesArray(), [Data([74, 97, 99, 111, 112, 111, 32, 240, 159, 152, 131])]) XCTAssertNil(feature.toInt()) XCTAssertNil(feature.toFloat()) XCTAssertNil(feature.toIntArray()) XCTAssertNil(feature.toFloatArray()) - XCTAssertNil(feature.toBytesArray()) - XCTAssertNil(feature.toStringArray()) } func testFloatArray() {