-
-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initializers callbacks constraints.. for keras
- Loading branch information
1 parent
36d1f95
commit 453d689
Showing
58 changed files
with
1,659 additions
and
202 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,29 +1,76 @@ | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
using System.Text; | ||
|
||
namespace MxNet.Keras.Callbacks | ||
{ | ||
public class BaseLogger : Callback | ||
{ | ||
public int seen; | ||
|
||
public string[] stateful_metrics; | ||
|
||
public Dictionary<string, float> totals; | ||
|
||
public BaseLogger(string[] stateful_metrics = null) | ||
{ | ||
throw new NotImplementedException(); | ||
this.stateful_metrics = stateful_metrics != null ? stateful_metrics : new string[0]; | ||
} | ||
|
||
public override void OnEpochBegin(int epoch, Dictionary<string, float> logs = null) | ||
{ | ||
throw new NotImplementedException(); | ||
seen = 0; | ||
totals = new Dictionary<string, float>(); | ||
} | ||
|
||
public override void OnBatchEnd(int batch, Dictionary<string, float> logs = null) | ||
{ | ||
throw new NotImplementedException(); | ||
if (logs == null) | ||
logs = new Dictionary<string, float>(); | ||
|
||
var batch_size = logs.ContainsKey("size") ? (int)logs["size"] : 0; | ||
this.seen += batch_size; | ||
foreach (var log in logs) | ||
{ | ||
var k = log.Key; | ||
var v = log.Value; | ||
if (this.stateful_metrics.Contains(k)) | ||
{ | ||
this.totals[k] = v; | ||
} | ||
else if (this.totals.ContainsKey(k)) | ||
{ | ||
this.totals[k] += v * batch_size; | ||
} | ||
else | ||
{ | ||
this.totals[k] = v * batch_size; | ||
} | ||
} | ||
} | ||
|
||
public override void OnEpochEnd(int epoch, Dictionary<string, float> logs = null) | ||
{ | ||
throw new NotImplementedException(); | ||
if (logs != null) | ||
{ | ||
string[] metrics = (string[])this.@params["metrics"]; | ||
foreach (var k in metrics) | ||
{ | ||
if (this.totals.ContainsKey(k)) | ||
{ | ||
// Make value available to next callbacks. | ||
if (this.stateful_metrics.Contains(k)) | ||
{ | ||
logs[k] = this.totals[k]; | ||
} | ||
else | ||
{ | ||
logs[k] = this.totals[k] / this.seen; | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,55 +1,139 @@ | ||
using MxNet.Keras.Engine; | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
using System.Text; | ||
|
||
namespace MxNet.Keras.Callbacks | ||
{ | ||
public class CallbackList | ||
{ | ||
public double _delta_t_batch; | ||
|
||
public List<double> _delta_ts_batch_begin; | ||
|
||
public List<double> _delta_ts_batch_end; | ||
|
||
public DateTime _t_enter_batch; | ||
|
||
public List<Callback> callbacks; | ||
|
||
public int queue_length; | ||
|
||
public CallbackList(Callback[] callbacks = null, int queue_length = 10) | ||
{ | ||
throw new NotImplementedException(); | ||
if (callbacks != null) | ||
this.callbacks = callbacks.ToList(); | ||
else | ||
this.callbacks = new List<Callback>(); | ||
|
||
this.queue_length = queue_length; | ||
} | ||
|
||
public void Append(Callback callback) | ||
{ | ||
throw new NotImplementedException(); | ||
this.callbacks.Add(callback); | ||
} | ||
|
||
public void SetParams(Dictionary<string, object> @params) | ||
{ | ||
throw new NotImplementedException(); | ||
foreach (var callback in this.callbacks) | ||
{ | ||
callback.SetParams(@params); | ||
} | ||
} | ||
|
||
public void SetModel(Model model) | ||
{ | ||
throw new NotImplementedException(); | ||
foreach (var callback in this.callbacks) | ||
{ | ||
callback.SetModel(model); | ||
} | ||
} | ||
|
||
public void OnEpochBegin(int epoch, Dictionary<string, float> logs = null) | ||
{ | ||
throw new NotImplementedException(); | ||
if (logs == null) | ||
logs = new Dictionary<string, float>(); | ||
|
||
foreach (var callback in this.callbacks) | ||
{ | ||
callback.OnEpochBegin(epoch, logs); | ||
} | ||
|
||
_delta_t_batch = 0; | ||
} | ||
|
||
public void OnEpochEnd(int epoch, Dictionary<string, float> logs = null) | ||
{ | ||
if (logs == null) | ||
logs = new Dictionary<string, float>(); | ||
|
||
foreach (var callback in this.callbacks) | ||
{ | ||
callback.OnEpochEnd(epoch, logs); | ||
} | ||
} | ||
|
||
public void OnBatchBegin(int batch, Dictionary<string, float> logs = null) | ||
{ | ||
throw new NotImplementedException(); | ||
if (logs == null) | ||
logs = new Dictionary<string, float>(); | ||
|
||
var t_before_callbacks = DateTime.Now; | ||
foreach (var callback in this.callbacks) | ||
{ | ||
callback.OnBatchBegin(batch, logs); | ||
} | ||
this._delta_ts_batch_begin.Add((DateTime.Now - t_before_callbacks).TotalMilliseconds); | ||
var delta_t_median = _delta_ts_batch_begin.Average(); | ||
if (this._delta_t_batch > 0.0 && delta_t_median > 0.95 * this._delta_t_batch && delta_t_median > 0.1) | ||
{ | ||
Logger.Warning($"Method on_batch_begin() is slow compared to the batch update ({delta_t_median}). Check your callbacks."); | ||
} | ||
|
||
this._t_enter_batch = DateTime.Now; | ||
} | ||
|
||
public void OnBatchEnd(int batch, Dictionary<string, float> logs = null) | ||
{ | ||
throw new NotImplementedException(); | ||
if (logs == null) | ||
logs = new Dictionary<string, float>(); | ||
|
||
this._delta_t_batch = (DateTime.Now - this._t_enter_batch).TotalMilliseconds; | ||
var t_before_callbacks = DateTime.Now; | ||
foreach (var callback in this.callbacks) | ||
{ | ||
callback.OnBatchEnd(batch, logs); | ||
} | ||
this._delta_ts_batch_end.Add((DateTime.Now - t_before_callbacks).TotalMilliseconds); | ||
var delta_t_median = _delta_ts_batch_end.Average(); | ||
if (this._delta_t_batch > 0.0 && (delta_t_median > 0.95 * this._delta_t_batch && delta_t_median > 0.1)) | ||
{ | ||
Logger.Warning($"In your callbacks, method `on_batch_end()` is slow compared to a model step ({delta_t_median} vs {_delta_t_batch}). Check your callbacks."); | ||
} | ||
} | ||
|
||
public void OnTrainBegin(Dictionary<string, float> logs = null) | ||
{ | ||
throw new NotImplementedException(); | ||
if (logs == null) | ||
logs = new Dictionary<string, float>(); | ||
|
||
foreach (var callback in this.callbacks) | ||
{ | ||
callback.OnTrainBegin(logs); | ||
} | ||
} | ||
|
||
public void OnTrainEnd(Dictionary<string, float> logs = null) | ||
{ | ||
throw new NotImplementedException(); | ||
if (logs == null) | ||
logs = new Dictionary<string, float>(); | ||
|
||
foreach (var callback in this.callbacks) | ||
{ | ||
callback.OnTrainEnd(logs); | ||
} | ||
} | ||
} | ||
} |
Oops, something went wrong.