Skip to content

Commit

Permalink
Initializers callbacks constraints.. for keras
Browse files Browse the repository at this point in the history
  • Loading branch information
deepakkumar1984 committed Apr 20, 2020
1 parent 36d1f95 commit 453d689
Show file tree
Hide file tree
Showing 58 changed files with 1,659 additions and 202 deletions.
5 changes: 1 addition & 4 deletions examples/ConsoleTest/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,7 @@ internal class Program
{
private static void Main(string[] args)
{
var im_fname = Utils.Download("https://raw.githubusercontent.com/zhreshold/mxnet-ssd/master/data/demo/dog.jpg", "dog.jpg");

var (x_img, img) = Yolo.LoadTest(im_fname, @short: 416);
Img.ImShow(img);
var arrays = NDArray.LoadNpz(@"C:\Users\deepa\Downloads\imdb.npz");
Console.ReadLine();
}
}
Expand Down
6 changes: 6 additions & 0 deletions src/MxNet.Keras/Backend/MxNetBackend.cs
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,12 @@ public static KerasSymbol GreaterEqual(KerasSymbol x, KerasSymbol y)
return new KerasSymbol(sym.BroadcastGreaterEqual(x.Symbol, y.Symbol));
}

public static KerasSymbol GreaterEqual(KerasSymbol x, float y)
{
var y_sym = sym.Full(y, x.Shape, dtype: x.DType);
return new KerasSymbol(sym.BroadcastGreaterEqual(x.Symbol, y_sym));
}

public static KerasSymbol Less(KerasSymbol x, KerasSymbol y)
{
return new KerasSymbol(sym.BroadcastLesser(x.Symbol, y.Symbol));
Expand Down
55 changes: 51 additions & 4 deletions src/MxNet.Keras/Callbacks/BaseLogger.cs
Original file line number Diff line number Diff line change
@@ -1,29 +1,76 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace MxNet.Keras.Callbacks
{
public class BaseLogger : Callback
{
public int seen;

public string[] stateful_metrics;

public Dictionary<string, float> totals;

public BaseLogger(string[] stateful_metrics = null)
{
throw new NotImplementedException();
this.stateful_metrics = stateful_metrics != null ? stateful_metrics : new string[0];
}

public override void OnEpochBegin(int epoch, Dictionary<string, float> logs = null)
{
throw new NotImplementedException();
seen = 0;
totals = new Dictionary<string, float>();
}

public override void OnBatchEnd(int batch, Dictionary<string, float> logs = null)
{
throw new NotImplementedException();
if (logs == null)
logs = new Dictionary<string, float>();

var batch_size = logs.ContainsKey("size") ? (int)logs["size"] : 0;
this.seen += batch_size;
foreach (var log in logs)
{
var k = log.Key;
var v = log.Value;
if (this.stateful_metrics.Contains(k))
{
this.totals[k] = v;
}
else if (this.totals.ContainsKey(k))
{
this.totals[k] += v * batch_size;
}
else
{
this.totals[k] = v * batch_size;
}
}
}

public override void OnEpochEnd(int epoch, Dictionary<string, float> logs = null)
{
throw new NotImplementedException();
if (logs != null)
{
string[] metrics = (string[])this.@params["metrics"];
foreach (var k in metrics)
{
if (this.totals.ContainsKey(k))
{
// Make value available to next callbacks.
if (this.stateful_metrics.Contains(k))
{
logs[k] = this.totals[k];
}
else
{
logs[k] = this.totals[k] / this.seen;
}
}
}
}
}
}
}
56 changes: 52 additions & 4 deletions src/MxNet.Keras/Callbacks/CSVLogger.cs
Original file line number Diff line number Diff line change
@@ -1,19 +1,66 @@
using System;
using CsvHelper;
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;

namespace MxNet.Keras.Callbacks
{
public class CSVLogger : Callback
{
public Dictionary<string, string> _open_args;

public bool append;

public bool append_header;

public FileStream csv_file;

public string file_flags;

public string filename;

public object keys;

public string sep;

public CsvWriter writer;

public CSVLogger(string filename, string separator= ",", bool append= false)
{
throw new NotImplementedException();
this.sep = separator;
this.filename = filename;
this.append = append;
this.writer = null;
this.keys = null;
this.append_header = true;
this.file_flags = "";
this._open_args = new Dictionary<string, string> {
{
"newline",
"\n"
}
};
}

public override void OnTrainBegin(Dictionary<string, float> logs = null)
{
throw new NotImplementedException();
object mode;
if (this.append)
{
if (File.Exists(this.filename))
{
append_header = File.ReadAllLines(filename).Length > 0;
}

mode = "a";
}
else
{
mode = "w";
}

this.csv_file = File.OpenWrite(filename);
}

public override void OnEpochEnd(int epoch, Dictionary<string, float> logs = null)
Expand All @@ -23,7 +70,8 @@ public override void OnEpochEnd(int epoch, Dictionary<string, float> logs = null

public override void OnTrainEnd(Dictionary<string, float> logs = null)
{
throw new NotImplementedException();
csv_file.Close();
writer = null;
}
}
}
13 changes: 10 additions & 3 deletions src/MxNet.Keras/Callbacks/Callback.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,26 @@ namespace MxNet.Keras.Callbacks
{
public abstract class Callback
{
public Model model;

public Dictionary<string, object> @params;

public NDArrayList validation_data;

public Callback()
{
throw new NotImplementedException();
this.validation_data = null;
this.model = null;
}

public virtual void SetParams(Dictionary<string, object> @params)
{
throw new NotImplementedException();
this.@params = @params;
}

public virtual void SetModel(Model model)
{
throw new NotImplementedException();
this.model = model;
}

public virtual void OnEpochBegin(int epoch, Dictionary<string, float> logs = null)
Expand Down
102 changes: 93 additions & 9 deletions src/MxNet.Keras/Callbacks/CallbackList.cs
Original file line number Diff line number Diff line change
@@ -1,55 +1,139 @@
using MxNet.Keras.Engine;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace MxNet.Keras.Callbacks
{
public class CallbackList
{
public double _delta_t_batch;

public List<double> _delta_ts_batch_begin;

public List<double> _delta_ts_batch_end;

public DateTime _t_enter_batch;

public List<Callback> callbacks;

public int queue_length;

public CallbackList(Callback[] callbacks = null, int queue_length = 10)
{
throw new NotImplementedException();
if (callbacks != null)
this.callbacks = callbacks.ToList();
else
this.callbacks = new List<Callback>();

this.queue_length = queue_length;
}

public void Append(Callback callback)
{
throw new NotImplementedException();
this.callbacks.Add(callback);
}

public void SetParams(Dictionary<string, object> @params)
{
throw new NotImplementedException();
foreach (var callback in this.callbacks)
{
callback.SetParams(@params);
}
}

public void SetModel(Model model)
{
throw new NotImplementedException();
foreach (var callback in this.callbacks)
{
callback.SetModel(model);
}
}

public void OnEpochBegin(int epoch, Dictionary<string, float> logs = null)
{
throw new NotImplementedException();
if (logs == null)
logs = new Dictionary<string, float>();

foreach (var callback in this.callbacks)
{
callback.OnEpochBegin(epoch, logs);
}

_delta_t_batch = 0;
}

public void OnEpochEnd(int epoch, Dictionary<string, float> logs = null)
{
if (logs == null)
logs = new Dictionary<string, float>();

foreach (var callback in this.callbacks)
{
callback.OnEpochEnd(epoch, logs);
}
}

public void OnBatchBegin(int batch, Dictionary<string, float> logs = null)
{
throw new NotImplementedException();
if (logs == null)
logs = new Dictionary<string, float>();

var t_before_callbacks = DateTime.Now;
foreach (var callback in this.callbacks)
{
callback.OnBatchBegin(batch, logs);
}
this._delta_ts_batch_begin.Add((DateTime.Now - t_before_callbacks).TotalMilliseconds);
var delta_t_median = _delta_ts_batch_begin.Average();
if (this._delta_t_batch > 0.0 && delta_t_median > 0.95 * this._delta_t_batch && delta_t_median > 0.1)
{
Logger.Warning($"Method on_batch_begin() is slow compared to the batch update ({delta_t_median}). Check your callbacks.");
}

this._t_enter_batch = DateTime.Now;
}

public void OnBatchEnd(int batch, Dictionary<string, float> logs = null)
{
throw new NotImplementedException();
if (logs == null)
logs = new Dictionary<string, float>();

this._delta_t_batch = (DateTime.Now - this._t_enter_batch).TotalMilliseconds;
var t_before_callbacks = DateTime.Now;
foreach (var callback in this.callbacks)
{
callback.OnBatchEnd(batch, logs);
}
this._delta_ts_batch_end.Add((DateTime.Now - t_before_callbacks).TotalMilliseconds);
var delta_t_median = _delta_ts_batch_end.Average();
if (this._delta_t_batch > 0.0 && (delta_t_median > 0.95 * this._delta_t_batch && delta_t_median > 0.1))
{
Logger.Warning($"In your callbacks, method `on_batch_end()` is slow compared to a model step ({delta_t_median} vs {_delta_t_batch}). Check your callbacks.");
}
}

public void OnTrainBegin(Dictionary<string, float> logs = null)
{
throw new NotImplementedException();
if (logs == null)
logs = new Dictionary<string, float>();

foreach (var callback in this.callbacks)
{
callback.OnTrainBegin(logs);
}
}

public void OnTrainEnd(Dictionary<string, float> logs = null)
{
throw new NotImplementedException();
if (logs == null)
logs = new Dictionary<string, float>();

foreach (var callback in this.callbacks)
{
callback.OnTrainEnd(logs);
}
}
}
}
Loading

0 comments on commit 453d689

Please sign in to comment.