forked from dotnet/machinelearning-samples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
OnnxOutputParser.cs
206 lines (169 loc) · 9.23 KB
/
OnnxOutputParser.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
using System;
using System.Collections.Generic;
using System.Drawing;
using System.Linq;
namespace OnnxObjectDetection
{
    /// <summary>
    /// Decodes the flattened output tensor of a YOLO-style ONNX object-detection model
    /// (for example Tiny YOLOv2) into <see cref="BoundingBox"/> instances, and filters
    /// overlapping detections with non-maximum suppression.
    /// </summary>
    public class OnnxOutputParser
    {
        // Raw box features as read from the model output (before mapping to image space),
        // plus the sigmoid-squashed objectness confidence of the box.
        class BoundingBoxPrediction : BoundingBoxDimensions
        {
            public float Confidence { get; set; }
        }

        // The number of rows and columns in the grid the image is divided into.
        public const int rowCount = 13, columnCount = 13;

        // The number of features contained within a box (x, y, width, height, confidence).
        public const int featuresPerBox = 5;

        // Labels corresponding to the classes the ONNX model can predict. For example, the
        // Tiny YOLOv2 model included with this sample is trained to predict 20 different classes.
        private readonly string[] classLabels;

        // Predetermined anchor scale factors, one entry per box in a cell: x scales the box
        // width and y scales the box height (both in grid-cell units, see MapBoundingBoxToCell).
        private readonly (float x, float y)[] boxAnchors;

        /// <summary>Creates a parser for the labels and anchors of the given model.</summary>
        /// <param name="onnxModel">Model supplying the class labels and box anchors.</param>
        /// <exception cref="ArgumentNullException"><paramref name="onnxModel"/> is null.</exception>
        public OnnxOutputParser(IOnnxModel onnxModel)
        {
            if (onnxModel == null)
            {
                throw new ArgumentNullException(nameof(onnxModel));
            }

            classLabels = onnxModel.Labels;
            boxAnchors = onnxModel.Anchors;
        }

        // Logistic sigmoid, mapping any real value into the range 0..1.
        // Written as 1 / (1 + e^-x) instead of e^x / (1 + e^x): the latter evaluates to
        // inf / inf = NaN once e^x overflows a float (x > ~88.7), whereas this form
        // saturates cleanly to 1 (and to 0 for very negative inputs).
        private static float Sigmoid(float value) => 1f / (1f + MathF.Exp(-value));

        // Normalizes an input vector into a probability distribution. The maximum is
        // subtracted before exponentiating so large scores cannot overflow. Each exponential
        // is computed exactly once; an empty input yields an empty distribution.
        private static float[] Softmax(float[] scores)
        {
            if (scores.Length == 0)
            {
                return Array.Empty<float>();
            }

            var max = scores.Max();
            var exponentials = new float[scores.Length];
            double sum = 0;
            for (int i = 0; i < scores.Length; i++)
            {
                exponentials[i] = MathF.Exp(scores[i] - max);
                sum += exponentials[i];
            }

            for (int i = 0; i < exponentials.Length; i++)
            {
                exponentials[i] = (float)(exponentials[i] / sum);
            }

            return exponentials;
        }

        // ONNX outputs a tensor that has a shape of (for Tiny YOLOv2) 125x13x13. ML.NET flattens
        // this multi-dimensional tensor into a one-dimensional array. This method returns the
        // index of a given channel for a given grid cell.
        // NOTE: `row` is the fastest-varying index here and is also used as the horizontal (X)
        // cell coordinate in MapBoundingBoxToCell, while `column` is used as the vertical (Y)
        // one. This relies on the grid being square (rowCount == columnCount).
        private static int GetOffset(int row, int column, int channel)
        {
            const int channelStride = rowCount * columnCount;
            return (channel * channelStride) + (column * columnCount) + row;
        }

        // Extracts the bounding box features (x, y, width, height, confidence) of the box whose
        // first channel is `channel` from the model output. The confidence value states how
        // sure the model is that it has detected an object; the sigmoid turns it into a 0..1 score.
        private static BoundingBoxPrediction ExtractBoundingBoxPrediction(float[] modelOutput, int row, int column, int channel)
        {
            return new BoundingBoxPrediction
            {
                X = modelOutput[GetOffset(row, column, channel)],
                Y = modelOutput[GetOffset(row, column, channel + 1)],
                Width = modelOutput[GetOffset(row, column, channel + 2)],
                Height = modelOutput[GetOffset(row, column, channel + 3)],
                Confidence = Sigmoid(modelOutput[GetOffset(row, column, channel + 4)])
            };
        }

        // The predicted x and y coordinates are relative to the location of the grid cell; the
        // logistic sigmoid constrains them to the range 0..1. Then the cell coordinates
        // (0..12) are added and the sum is multiplied by the size of a grid cell in pixels
        // (32 for a 416x416 input). Now x/y represent the center of the bounding box in the
        // original image space. The size (width, height) of the bounding box is predicted
        // relative to the size of an "anchor" box, so it is transformed into image space too.
        // Finally the center is converted to the top-left corner of the box.
        private BoundingBoxDimensions MapBoundingBoxToCell(int row, int column, int box, BoundingBoxPrediction boxDimensions)
        {
            // Floating-point division: integer division would silently truncate the cell size
            // for image sizes that are not an exact multiple of the grid size.
            const float cellWidth = (float)ImageSettings.imageWidth / columnCount;
            const float cellHeight = (float)ImageSettings.imageHeight / rowCount;

            var width = MathF.Exp(boxDimensions.Width) * cellWidth * boxAnchors[box].x;
            var height = MathF.Exp(boxDimensions.Height) * cellHeight * boxAnchors[box].y;
            var centerX = (row + Sigmoid(boxDimensions.X)) * cellWidth;
            var centerY = (column + Sigmoid(boxDimensions.Y)) * cellHeight;

            return new BoundingBoxDimensions
            {
                X = centerX - width / 2,
                Y = centerY - height / 2,
                Width = width,
                Height = height,
            };
        }

        /// <summary>
        /// Extracts the class scores for one bounding box from the model output, turns them into a
        /// probability distribution with softmax, and scales every probability by the box confidence.
        /// </summary>
        /// <param name="modelOutput">Flattened model output tensor.</param>
        /// <param name="row">Grid cell index used as the fastest-varying (X) coordinate.</param>
        /// <param name="column">Grid cell index used as the Y coordinate.</param>
        /// <param name="channel">First channel of the box (its x feature).</param>
        /// <param name="confidence">Objectness confidence of the box (0..1).</param>
        /// <returns>One score per entry of the model's label list.</returns>
        public float[] ExtractClassProbabilities(float[] modelOutput, int row, int column, int channel, float confidence)
        {
            // Class scores follow directly after the box features.
            var classScoresOffset = channel + featuresPerBox;
            var classScores = new float[classLabels.Length];
            for (int i = 0; i < classScores.Length; i++)
            {
                classScores[i] = modelOutput[GetOffset(row, column, classScoresOffset + i)];
            }

            var probabilities = Softmax(classScores);
            for (int i = 0; i < probabilities.Length; i++)
            {
                probabilities[i] *= confidence;
            }

            return probabilities;
        }

        // IoU (Intersection over Union) measures the overlap between two boxes as
        // intersection area / union area, in the range 0..1. FilterBoundingBoxes uses it to
        // drop boxes that overlap a higher-confidence box too much. Degenerate boxes
        // (zero or negative area) are treated as non-overlapping.
        private static float IntersectionOverUnion(RectangleF boundingBoxA, RectangleF boundingBoxB)
        {
            var areaA = boundingBoxA.Width * boundingBoxA.Height;
            var areaB = boundingBoxB.Width * boundingBoxB.Height;
            if (areaA <= 0 || areaB <= 0)
            {
                return 0;
            }

            var left = MathF.Max(boundingBoxA.Left, boundingBoxB.Left);
            var top = MathF.Max(boundingBoxA.Top, boundingBoxB.Top);
            var right = MathF.Min(boundingBoxA.Right, boundingBoxB.Right);
            var bottom = MathF.Min(boundingBoxA.Bottom, boundingBoxB.Bottom);

            var intersectionArea = MathF.Max(right - left, 0) * MathF.Max(bottom - top, 0);
            return intersectionArea / (areaA + areaB - intersectionArea);
        }

        /// <summary>
        /// Decodes every anchor box of every grid cell in the model output and returns the boxes
        /// whose objectness confidence and best class score both reach <paramref name="probabilityThreshold"/>.
        /// </summary>
        /// <param name="modelOutput">Flattened model output tensor.</param>
        /// <param name="probabilityThreshold">Minimum confidence / class score for a box to be kept.</param>
        /// <returns>The detected boxes, in grid order (not sorted by confidence).</returns>
        /// <exception cref="ArgumentNullException"><paramref name="modelOutput"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="modelOutput"/> is shorter than the
        /// anchors x (labels + box features) x grid cells values this parser reads.</exception>
        public List<BoundingBox> ParseOutputs(float[] modelOutput, float probabilityThreshold = .3f)
        {
            if (modelOutput == null)
            {
                throw new ArgumentNullException(nameof(modelOutput));
            }

            var channelsPerBox = classLabels.Length + featuresPerBox;
            var expectedLength = boxAnchors.Length * channelsPerBox * rowCount * columnCount;
            if (modelOutput.Length < expectedLength)
            {
                throw new ArgumentException(
                    $"Expected at least {expectedLength} values in the model output but got {modelOutput.Length}.",
                    nameof(modelOutput));
            }

            var boxes = new List<BoundingBox>();
            for (int row = 0; row < rowCount; row++)
            {
                for (int column = 0; column < columnCount; column++)
                {
                    for (int box = 0; box < boxAnchors.Length; box++)
                    {
                        var channel = box * channelsPerBox;
                        var prediction = ExtractBoundingBoxPrediction(modelOutput, row, column, channel);
                        if (prediction.Confidence < probabilityThreshold)
                        {
                            continue;
                        }

                        var classProbabilities = ExtractClassProbabilities(modelOutput, row, column, channel, prediction.Confidence);

                        // Arg-max over the class scores; on ties the later index wins.
                        var topIndex = -1;
                        var topProbability = float.NegativeInfinity;
                        for (int i = 0; i < classProbabilities.Length; i++)
                        {
                            if (classProbabilities[i] >= topProbability)
                            {
                                topProbability = classProbabilities[i];
                                topIndex = i;
                            }
                        }

                        if (topIndex < 0 || topProbability < probabilityThreshold)
                        {
                            continue;
                        }

                        // Only boxes that survive both thresholds are mapped into image space.
                        boxes.Add(new BoundingBox
                        {
                            Dimensions = MapBoundingBoxToCell(row, column, box, prediction),
                            Confidence = topProbability,
                            Label = classLabels[topIndex],
                            BoxColor = BoundingBox.GetColor(topIndex)
                        });
                    }
                }
            }

            return boxes;
        }

        /// <summary>
        /// Greedy, class-agnostic non-maximum suppression. Boxes are visited in descending
        /// confidence order. Each box not yet suppressed is kept, and every later box whose IoU
        /// with it exceeds <paramref name="iouThreshold"/> is suppressed.
        /// </summary>
        /// <param name="boxes">Candidate boxes; the list itself is not modified.</param>
        /// <param name="limit">Maximum number of boxes to return; zero or negative returns none.</param>
        /// <param name="iouThreshold">IoU above which a lower-confidence box is discarded.</param>
        /// <returns>The kept boxes, highest confidence first.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="boxes"/> is null.</exception>
        public List<BoundingBox> FilterBoundingBoxes(List<BoundingBox> boxes, int limit, float iouThreshold)
        {
            if (boxes == null)
            {
                throw new ArgumentNullException(nameof(boxes));
            }

            var results = new List<BoundingBox>();
            if (limit <= 0)
            {
                return results;
            }

            var sortedBoxes = boxes.OrderByDescending(b => b.Confidence).ToArray();
            var suppressed = new bool[sortedBoxes.Length];

            for (int i = 0; i < sortedBoxes.Length; i++)
            {
                if (suppressed[i])
                {
                    continue;
                }

                var kept = sortedBoxes[i];
                results.Add(kept);
                if (results.Count >= limit)
                {
                    break;
                }

                // Compare against ALL remaining lower-confidence boxes. (The previous version
                // stopped this scan as soon as no box had been suppressed yet, so most
                // overlapping duplicates survived.)
                for (int j = i + 1; j < sortedBoxes.Length; j++)
                {
                    if (!suppressed[j] && IntersectionOverUnion(kept.Rect, sortedBoxes[j].Rect) > iouThreshold)
                    {
                        suppressed[j] = true;
                    }
                }
            }

            return results;
        }
    }
}