Skip to content

Commit 12f036e

Browse files
committed
Change 'NeuralNetwork.Step' to only use the current step's cached outputs and an option to always reevaluate self-referencing nodes
1 parent c7f4b8b commit 12f036e

File tree

3 files changed

+34
-54
lines changed

3 files changed

+34
-54
lines changed

MaceEvolve/Controls/GameHost.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,11 @@ public int TargetTPS
5353
public Rectangle WorldBounds { get; set; }
5454
public Rectangle SuccessBounds { get; set; }
5555
public int MinCreatureConnections { get; set; } = 4;
56-
public int MaxCreatureConnections { get; set; } = 64;
56+
public int MaxCreatureConnections { get; set; } = 128;
5757
public double CreatureSpeed { get; set; }
5858
public double NewGenerationInterval { get; set; } = 12;
5959
public double SecondsUntilNewGeneration { get; set; } = 12;
60-
public int MaxCreatureProcessNodes { get; set; } = 3;
60+
public int MaxCreatureProcessNodes { get; set; } = 5;
6161
public double MutationChance { get; set; } = 0.1;
6262
public double MutationAttempts { get; set; } = 10;
6363
public double ConnectionWeightBound { get; set; } = 4;

MaceEvolve/Models/Creature.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ public void Live()
104104
UpdateCurrentStepInfo();
105105
UpdateInputValues();
106106

107-
Dictionary<int, double> NodeIdToOutputDict = Brain.LoggedStep(true, false);
107+
Dictionary<int, double> NodeIdToOutputDict = Brain.LoggedStep(true, true);
108108
Dictionary<Node, double> NodeOutputsDict = NodeIdToOutputDict.OrderBy(x => x.Value).ToDictionary(x => Brain.Nodes[x.Key], x => x.Value);
109109
Node HighestOutputNode = NodeOutputsDict.Keys.LastOrDefault(x => x.NodeType == NodeType.Output);
110110

MaceEvolve/Models/NeuralNetwork.cs

Lines changed: 31 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@ public class NeuralNetwork
2121
[JsonIgnore]
2222
public List<Node> Nodes { get; } = new List<Node>();
2323
public List<Connection> Connections { get; set; } = new List<Connection>();
24-
25-
public Dictionary<Node, double> PreviousNodeOutputs { get; set; } = new Dictionary<Node, double>();
2624
public List<NeuralNetworkStepInfo> PreviousStepInfo { get; set; } = new List<NeuralNetworkStepInfo>();
2725
#endregion
2826

@@ -414,7 +412,7 @@ public static IEnumerable<Node> GetPossibleTargetNodes(IEnumerable<Node> Nodes)
414412
{
415413
return Nodes.Where(x => x.NodeType == NodeType.Process || x.NodeType == NodeType.Output);
416414
}
417-
public Dictionary<int, double> Step(bool OutputNodesOnly, bool CacheSelfReferencingConnections)
415+
public Dictionary<int, double> Step(bool OutputNodesOnly, bool AlwaysReevaluateNodesWithSelfReferencingConnections)
418416
{
419417
if (OutputNodesOnly)
420418
{
@@ -477,52 +475,43 @@ public Dictionary<int, double> Step(bool OutputNodesOnly, bool CacheSelfReferenc
477475
double SourceNodeOutput;
478476
Node ConnectionSourceNode = IdToNodeDict[Connection.SourceId];
479477

480-
if (NodesBeingEvaluated.Contains(Connection.SourceId) || (CurrentNodeHasSelfReferencingConnections && !CacheSelfReferencingConnections))
478+
//If the source node is already being evaluated, meaning either the current connection is a circular reference or the source node is present earlier in the queue and is a circular reference,
479+
//The cached output of the source node must be used. If there is no cached value, initialise one with a value of 0.
480+
//OR
481+
//Whether the node is evaluated or not, if there is a self referencing connection, use the specified parameter to determine whether it should be evaluated again or not.
482+
//This is important because after a self referencing node's output is calculated, it is cached. When getting the value of that node again, something needs to decided whether to use the original output
483+
//or to calculate a new output using the cached output to resolve the circular reference instead of the initial value of 0.
484+
if (NodesBeingEvaluated.Contains(Connection.SourceId) || !(CurrentNodeHasSelfReferencingConnections && !AlwaysReevaluateNodesWithSelfReferencingConnections))
481485
{
482-
if (PreviousNodeOutputs.TryGetValue(ConnectionSourceNode, out double PreviousSourceNodeOutput))
486+
if (CachedNodeOutputs.TryGetValue(Connection.SourceId, out double CachedSourceNodeOutput))
483487
{
484-
SourceNodeOutput = PreviousSourceNodeOutput;
488+
SourceNodeOutput = CachedSourceNodeOutput;
485489
}
486490
else
487491
{
488-
PreviousNodeOutputs[ConnectionSourceNode] = 0;
489-
SourceNodeOutput = PreviousNodeOutputs[ConnectionSourceNode];
492+
CachedNodeOutputs[Connection.SourceId] = 0;
493+
SourceNodeOutput = CachedNodeOutputs[Connection.SourceId];
490494
}
491495
}
492496
else
493497
{
494-
if (CachedNodeOutputs.TryGetValue(Connection.SourceId, out double CachedSourceNodeOutput))
495-
{
496-
SourceNodeOutput = CachedSourceNodeOutput;
497-
}
498-
else
499-
{
500-
NodeQueue.Add(Connection.SourceId);
501-
CurrentNodeWeightedSum = null;
502-
break;
503-
}
498+
NodeQueue.Add(Connection.SourceId);
499+
CurrentNodeWeightedSum = null;
500+
break;
504501
}
505502

506-
//CurrentNodeWeightedSum ??= 0;
507-
508503
CurrentNodeWeightedSum += SourceNodeOutput * Connection.Weight;
509504
}
510505
}
511506
}
512507

513508
if (CurrentNodeWeightedSum != null)
514509
{
515-
double CurrentNodeOutput = Globals.ReLU(CurrentNodeWeightedSum.Value + CurrentNode.Bias);
516-
517-
if (CurrentNodeOutput > 10000)
518-
{
519-
var thing = 2;
520-
}
510+
double CurrentNodeOutput = CurrentNode.NodeType == NodeType.Input ? CurrentNodeWeightedSum.Value : Globals.ReLU(CurrentNodeWeightedSum.Value + CurrentNode.Bias);
521511

522512
NodesBeingEvaluated.Remove(CurrentNodeId);
523513

524514
CachedNodeOutputs[CurrentNodeId] = CurrentNodeOutput;
525-
PreviousNodeOutputs[CurrentNode] = CurrentNodeOutput;
526515

527516
NodeQueue.Remove(CurrentNodeId);
528517
}
@@ -535,7 +524,7 @@ public Dictionary<int, double> Step(bool OutputNodesOnly, bool CacheSelfReferenc
535524
throw new NotImplementedException();
536525
}
537526
}
538-
public Dictionary<int, double> LoggedStep(bool OutputNodesOnly, bool CacheSelfReferencingConnections)
527+
public Dictionary<int, double> LoggedStep(bool OutputNodesOnly, bool AlwaysReevaluateNodesWithSelfReferencingConnections)
539528
{
540529
Dictionary<Node, int> NodeToIdDict = Nodes.ToDictionary(x => x, x => GetNodeId(x));
541530

@@ -631,34 +620,31 @@ public Dictionary<int, double> LoggedStep(bool OutputNodesOnly, bool CacheSelfRe
631620
double SourceNodeOutput;
632621
Node ConnectionSourceNode = IdToNodeDict[Connection.SourceId];
633622

634-
if (NodesBeingEvaluated.Contains(Connection.SourceId) || (CurrentNodeHasSelfReferencingConnections && !CacheSelfReferencingConnections))
623+
//If the source node is already being evaluated, meaning either the current connection is a circular reference or the source node is present earlier in the queue and is a circular reference,
624+
//The cached output of the source node must be used. If there is no cached value, initialise one with a value of 0.
625+
//OR
626+
//Whether the node is evaluated or not, if there is a self referencing connection, use the specified parameter to determine whether it should be evaluated again or not.
627+
//This is important because after a self referencing node's output is calculated, it is cached. When getting the value of that node again, something needs to decided whether to use the original output
628+
//or to calculate a new output using the cached output to resolve the circular reference instead of the initial value of 0.
629+
if (NodesBeingEvaluated.Contains(Connection.SourceId) || !(CurrentNodeHasSelfReferencingConnections && !AlwaysReevaluateNodesWithSelfReferencingConnections))
635630
{
636-
if (PreviousNodeOutputs.TryGetValue(ConnectionSourceNode, out double PreviousSourceNodeOutput))
631+
if (CachedNodeOutputs.TryGetValue(Connection.SourceId, out double CachedSourceNodeOutput))
637632
{
638-
SourceNodeOutput = PreviousSourceNodeOutput;
633+
SourceNodeOutput = CachedSourceNodeOutput;
639634
}
640635
else
641636
{
642-
PreviousNodeOutputs[ConnectionSourceNode] = 0;
643-
SourceNodeOutput = PreviousNodeOutputs[ConnectionSourceNode];
637+
CachedNodeOutputs[Connection.SourceId] = 0;
638+
SourceNodeOutput = CachedNodeOutputs[Connection.SourceId];
644639
}
645640
}
646641
else
647642
{
648-
if (CachedNodeOutputs.TryGetValue(Connection.SourceId, out double CachedSourceNodeOutput))
649-
{
650-
SourceNodeOutput = CachedSourceNodeOutput;
651-
}
652-
else
653-
{
654-
NodeQueue.Add(Connection.SourceId);
655-
CurrentNodeWeightedSum = null;
656-
break;
657-
}
643+
NodeQueue.Add(Connection.SourceId);
644+
CurrentNodeWeightedSum = null;
645+
break;
658646
}
659647

660-
//CurrentNodeWeightedSum ??= 0;
661-
662648
CurrentNodeWeightedSum += SourceNodeOutput * Connection.Weight;
663649
}
664650
}
@@ -668,15 +654,9 @@ public Dictionary<int, double> LoggedStep(bool OutputNodesOnly, bool CacheSelfRe
668654
{
669655
double CurrentNodeOutput = CurrentNode.NodeType == NodeType.Input ? CurrentNodeWeightedSum.Value : Globals.ReLU(CurrentNodeWeightedSum.Value + CurrentNode.Bias);
670656

671-
if (CurrentNodeOutput > 10000)
672-
{
673-
var thing = 2;
674-
}
675-
676657
NodesBeingEvaluated.Remove(CurrentNodeId);
677658

678659
CachedNodeOutputs[CurrentNodeId] = CurrentNodeOutput;
679-
PreviousNodeOutputs[CurrentNode] = CurrentNodeOutput;
680660
CurrentNodeStepInfo.PreviousOutput = CurrentNodeOutput;
681661

682662
NodeQueue.Remove(CurrentNodeId);

0 commit comments

Comments
 (0)