From 6f2ed81d043a9f3713709ad1e56ad920fc40fc3d Mon Sep 17 00:00:00 2001
From: mikesh <mikhail.shugay@gmail.com>
Date: Tue, 17 Feb 2015 19:11:28 +0300
Subject: [PATCH] Implemented loading of pre-trained classifier. Close #13

---
 .../oncomigec/pipeline/MigecCli.java          | 32 ++++++++++++++++---
 .../oncomigec/pipeline/MigecPipeline.java     | 11 ++++++-
 2 files changed, 37 insertions(+), 6 deletions(-)

diff --git a/src/main/java/com/milaboratory/oncomigec/pipeline/MigecCli.java b/src/main/java/com/milaboratory/oncomigec/pipeline/MigecCli.java
index bb0b9b2..830bb82 100644
--- a/src/main/java/com/milaboratory/oncomigec/pipeline/MigecCli.java
+++ b/src/main/java/com/milaboratory/oncomigec/pipeline/MigecCli.java
@@ -2,6 +2,8 @@
 
 import com.milaboratory.oncomigec.core.io.misc.MigReaderParameters;
 import com.milaboratory.oncomigec.core.io.misc.UmiHistogram;
+import com.milaboratory.oncomigec.model.classifier.BaseVariantClassifier;
+import com.milaboratory.oncomigec.model.classifier.VariantClassifier;
 import org.apache.commons.cli.*;
 import org.apache.commons.io.FileUtils;
 import org.apache.commons.io.FilenameUtils;
@@ -52,7 +54,8 @@ public static void main(String[] args) throws Exception {
                 OPT_IMPORT_PRESET = "import-preset", OPT_EXPORT_PRESET = "export-preset",
                 OPT_EXOME_LONG = "exome-mode", OPT_EXOME_SHORT = "E",
                 OPT_TEST_LONG = "test-mode", OPT_TEST_SHORT = "T",
-                OPT_APPEND_LONG = "append-mode", OPT_APPEND_SHORT = "A",
+                OPT_APPEND = "append-mode",
+                OPT_CLASSIFIER_FILE = "load-classifier",
                 OPT_BARCODES_LONG = "barcodes", OPT_BARCODES_SHORT = "B",
                 OPT_NO_BARCODES_LONG = "no-barcodes", OPT_NO_BARCODES_SHORT = "N",
                 OPT_REFERENCES_LONG = "references", OPT_REFERENCES_SHORT = "R",
@@ -101,6 +104,14 @@ public static void main(String[] args) throws Exception {
                                 .withDescription("output current parameter preset to the specified XML file")
                                 .withLongOpt(OPT_EXPORT_PRESET)
                                 .create()
+                )
+                .addOption(
+                        OptionBuilder
+                                .withArgName("file")
+                                .hasArg(true)
+                                .withDescription("specifies a pre-trained classifier binary file (Weka model)")
+                                .withLongOpt(OPT_CLASSIFIER_FILE)
+                                .create()
                 )
                         //
                         // modes
@@ -109,8 +120,8 @@ public static void main(String[] args) throws Exception {
                                 .hasArg(false)
                                 .withDescription("append mode, " +
                                         "will not overwrite files if specified")
-                                .withLongOpt(OPT_APPEND_LONG)
-                                .create(OPT_APPEND_SHORT)
+                                .withLongOpt(OPT_APPEND)
+                                .create()
                 )
                 .addOption(
                         OptionBuilder
@@ -207,6 +218,7 @@ public static void main(String[] args) throws Exception {
         // create the parser
         CommandLineParser parser = new BasicParser();
         MigecPipeline pipeline = null;
+        VariantClassifier variantClassifier = null;
         File outputFolder = null;
         double dumpFreq = -1;
 
@@ -249,10 +261,15 @@ public static void main(String[] args) throws Exception {
                 System.exit(0);
             }
 
+            if (commandLine.hasOption(OPT_CLASSIFIER_FILE)) {
+                File classifierFile = new File(commandLine.getOptionValue(OPT_CLASSIFIER_FILE));
+                variantClassifier = BaseVariantClassifier.pretrained(classifierFile);
+            }
+
             // mode
             if (!commandLine.hasOption(OPT_EXOME_SHORT) && !commandLine.hasOption(OPT_TEST_SHORT))
                 throw new ParseException("No mode has been set");
-            if (commandLine.hasOption(OPT_APPEND_SHORT))
+            if (commandLine.hasOption(OPT_APPEND))
                 appendMode = true;
 
             // barcodes
@@ -298,7 +315,7 @@ else if (!appendMode)
             // =================
             // Pipeline creation
             // =================
-            print1("Running MiGEC v" + MigecCli.class.getPackage().getImplementationVersion() +
+            print1("Running OncoMIGEC v" + MigecCli.class.getPackage().getImplementationVersion() +
                             " for " +
                             (paired ?
                                     (commandLine.getOptionValue(OPT_FASTQ1_SHORT) +
@@ -405,6 +422,11 @@ else if (!appendMode)
             return;
         }
 
+        if (variantClassifier != null) {
+            // user-defined classifier
+            pipeline.setVariantClassifier(variantClassifier);
+        }
+
         runSecondStage(pipeline, outputFolder);
 
         print2("Finished");
diff --git a/src/main/java/com/milaboratory/oncomigec/pipeline/MigecPipeline.java b/src/main/java/com/milaboratory/oncomigec/pipeline/MigecPipeline.java
index e9e8e66..1e80cf8 100644
--- a/src/main/java/com/milaboratory/oncomigec/pipeline/MigecPipeline.java
+++ b/src/main/java/com/milaboratory/oncomigec/pipeline/MigecPipeline.java
@@ -40,7 +40,7 @@ public class MigecPipeline {
     protected final Map<String, HaplotypeTree> haplotypeTreeBySample;
     protected final List<String> sampleNames, skippedSamples;
     protected final MigecParameterSet migecParameterSet;
-    protected final VariantClassifier variantClassifier = BaseVariantClassifier.BUILT_IN; // todo: implement loading from file
+    protected VariantClassifier variantClassifier;
 
     protected MigecPipeline(MigReader reader,
                             AssemblerFactory assemblerFactory,
@@ -60,6 +60,7 @@ protected MigecPipeline(MigReader reader,
             assemblerBySample.put(sampleName, assemblerFactory.create());
             alignerBySample.put(sampleName, consensusAlignerFactory.create());
         }
+        this.variantClassifier = BaseVariantClassifier.BUILT_IN;
     }
 
     public void skipSamples(List<String> samplesToSkip) {
@@ -216,6 +217,14 @@ public String getHaplotypeTreeFastaOutput(String sampleName) {
             return "";
     }
 
+    public VariantClassifier getVariantClassifier() {
+        return variantClassifier;
+    }
+
+    public void setVariantClassifier(VariantClassifier variantClassifier) {
+        this.variantClassifier = variantClassifier;
+    }
+
     public String getMinorVariantDump(double threshold) {
         String dump = "#SampleName\t" + Variant.HEADER;
         for (String sample : sampleNames) {