diff --git a/avrohugger-core/src/main/scala/format/specific/SpecificImporter.scala b/avrohugger-core/src/main/scala/format/specific/SpecificImporter.scala index 7207cb50..d057004c 100644 --- a/avrohugger-core/src/main/scala/format/specific/SpecificImporter.scala +++ b/avrohugger-core/src/main/scala/format/specific/SpecificImporter.scala @@ -25,8 +25,7 @@ object SpecificImporter extends Importer { val switchAnnotSymbol = RootClass.newClass("scala.annotation.switch") val switchImport = IMPORT(switchAnnotSymbol) - val topLevelSchemas = - getTopLevelSchemas(schemaOrProtocol, schemaStore, typeMatcher) + val topLevelSchemas = getTopLevelSchemas(schemaOrProtocol, schemaStore, typeMatcher) val recordSchemas = getRecordSchemas(topLevelSchemas) val enumSchemas = getEnumSchemas(topLevelSchemas) val userDefinedDeps = getUserDefinedImports(recordSchemas ++ enumSchemas, currentNamespace, typeMatcher) @@ -42,7 +41,6 @@ object SpecificImporter extends Importer { else libraryDeps ++ userDefinedDeps } case Right(protocol) => { - val types = protocol.getTypes().asScala.toList val messages = protocol.getMessages.asScala.toMap if (messages.isEmpty) switchImport :: libraryDeps ::: userDefinedDeps // for ADT else List.empty // for RPC diff --git a/avrohugger-core/src/main/scala/input/parsers/FileInputParser.scala b/avrohugger-core/src/main/scala/input/parsers/FileInputParser.scala index 022290be..38602de7 100644 --- a/avrohugger-core/src/main/scala/input/parsers/FileInputParser.scala +++ b/avrohugger-core/src/main/scala/input/parsers/FileInputParser.scala @@ -19,7 +19,7 @@ import scala.util.Try class FileInputParser { var processedFiles: Set[String] = Set.empty - var processedSchemas: Set[Either[Schema, Protocol]] = Set.empty + var processedSchemas: Set[Schema] = Set.empty def getSchemaOrProtocols( infile: File, @@ -110,13 +110,7 @@ class FileInputParser { } } - def stripImports( - protocol: Protocol, - importedSchemaOrProtocols: Set[Either[Schema, Protocol]]) = { - val imported = importedSchemaOrProtocols.flatMap { - case Left(importedSchema) => List(importedSchema) - case Right(importedProtocol) => importedProtocol.getTypes().asScala - } + def stripImports(protocol: Protocol, imported: Set[Schema]) = { val types = protocol.getTypes().asScala.toList val localTypes = types.filterNot(imported.contains) protocol.setTypes(localTypes.asJava) @@ -132,7 +126,10 @@ class FileInputParser { |".avsc" for plain text json files, ".avdl" for IDL files, or .avro |for binary.""".trim.stripMargin) } - res.foreach(processedSchemas += _) + res.foreach { + case Left(importedSchema) => processedSchemas += importedSchema + case Right(importedProtocol) => processedSchemas ++= importedProtocol.getTypes().asScala + } res } } \ No newline at end of file diff --git a/avrohugger-core/src/main/scala/input/parsers/IdlImportParser.scala b/avrohugger-core/src/main/scala/input/parsers/IdlImportParser.scala index 0508a31e..61289dc0 100644 --- a/avrohugger-core/src/main/scala/input/parsers/IdlImportParser.scala +++ b/avrohugger-core/src/main/scala/input/parsers/IdlImportParser.scala @@ -2,13 +2,11 @@ package avrohugger package input package parsers -import org.apache.avro.{ Protocol, Schema } - import java.io.File import scala.util.matching.Regex.Match object IdlImportParser { - + def stripComments(fileContents: String): String = { val multiLinePattern = """/\*.*\*/""".r val singleLinePattern = """//.*$""".r @@ -29,13 +27,14 @@ object IdlImportParser { // if file is empty, try again, it was there when we read idl if (fileContents.isEmpty && (count < maxTries)) readFile(infile) else fileContents - } catch {// if file is not found, try again, it was there when we read idl + } catch { // if file is not found, try again, it was there when we read idl case e: java.io.FileNotFoundException => { if (count < maxTries) readFile(infile) else sys.error("File to found: " + infile) } } } + val path = infile.getParent + "/" val contents = readFile(infile) val avdlPattern = """import[ \t]+idl[ \t]+"([^"]*\.avdl)"[ \t]*;""".r @@ -45,24 +44,23 @@ object IdlImportParser { val protocolMatches = avprPattern.findAllIn(contents).matchData.toList val schemaMatches = avscPattern.findAllIn(contents).matchData.toList val importMatches = idlMatches ::: protocolMatches ::: schemaMatches - + val (localImports, nonLocalMatches): (List[File], List[Match]) = - importMatches.foldLeft((List.empty[File], List.empty[Match])){ - case ((ai,am), m) => + importMatches.foldLeft((List.empty[File], List.empty[Match])) { + case ((ai, am), m) => val f = new File(path + m.group(1)) - if (f.exists) (ai:+f, am) - else (ai, am:+m) + if (f.exists) (ai :+ f, am) + else (ai, am :+ m) } - - val classpathImports: List[File] = nonLocalMatches.map(m =>{ - - Option(classLoader.getResource(m.group(1))).map(resource =>{ + + val classpathImports: List[File] = nonLocalMatches.flatMap { m => + Option(classLoader.getResource(m.group(1))).map(resource => { new File(resource.getFile) }) - }).flatMap(_.toList).filter(file => file.exists) + }.filter(_.exists) val importedFiles = classpathImports ++ localImports importedFiles } - + }