Skip to content

Commit

Permalink
add undirected graph and coloring algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Jan 27, 2025
1 parent 28a0023 commit ea2381e
Show file tree
Hide file tree
Showing 2 changed files with 228 additions and 2 deletions.
161 changes: 159 additions & 2 deletions lib/Utils/Graph/Graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
#include <algorithm>
#include <cstdint>
#include <map>
#include <queue>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
Expand Down Expand Up @@ -36,9 +39,9 @@ class Graph {

// Returns true iff the given vertex has previously been added to the graph
// using `AddVertex`.
bool contains(V vertex) { return vertices.count(vertex) > 0; }
bool contains(V vertex) const { return vertices.count(vertex) > 0; }

bool empty() { return vertices.empty(); }
bool empty() const { return vertices.empty(); }

const std::set<V>& getVertices() { return vertices; }

Expand Down Expand Up @@ -154,6 +157,160 @@ class Graph {
std::map<V, std::set<V>> inEdges;
};

// An undirected graph data structure.
//
// Parameter `V` is the vertex type, which is expected to be cheap to copy.
template <typename V>
class UndirectedGraph {
public:
// Adds a vertex to the graph
void addVertex(V vertex) { vertices.insert(vertex); }

// Adds an edge between `source` and `target`. Returns false if either the
// source or target is not a previously inserted vertex, and returns true
// otherwise. The graph is unchanged if false is returned.
bool addEdge(V source, V target) {
if (vertices.count(source) == 0 || vertices.count(target) == 0) {
return false;
}
edges[source].insert(target);
edges[target].insert(source);
return true;
}

// Returns true iff the given vertex has previously been added to the graph
// using `AddVertex`.
bool contains(V vertex) const { return vertices.count(vertex) > 0; }

bool empty() const { return vertices.empty(); }

const std::set<V>& getVertices() const { return vertices; }

// Returns the edges incident to the given vertex.
std::vector<V> edgesIncidentTo(V vertex) const {
if (vertices.count(vertex)) {
std::vector<V> result(edges.at(vertex).begin(), edges.at(vertex).end());
// Note: The vertices are sorted to ensure determinism in the output.
std::sort(result.begin(), result.end());
return result;
}
return {};
}

private:
std::set<V> vertices;
std::map<V, std::set<V>> edges;
};

/// The greedy "DSatur" graph coloring algorithm
/// Cf. https://en.wikipedia.org/wiki/DSatur
template <typename V>
class GreedyGraphColoring {
public:
GreedyGraphColoring() = default;
~GreedyGraphColoring() = default;

std::unordered_map<V, int> color(const UndirectedGraph<V>& graph) {
colors.clear();
neighborColors.clear();
vertexSaturations.clear();

for (const auto& vertex : graph.getVertices()) {
neighborColors[vertex] = std::unordered_set<int>();
updateSaturationDegree(graph, vertex);
}

while (!queue.empty()) {
auto current = queue.top();
queue.pop();

// Skip if vertex is already colored or info is outdated. Could avoid
// having "stale" info by using a std::set instead of a priority queue,
// but then that would incur log(n) lookups and log(n) updates. Probably
// the extra memory usage is fine since the graphs should be relatively
// sparse.
if (colors.find(current.vertex) != colors.end() ||
current.saturationDegree !=
vertexSaturations[current.vertex].saturationDegree ||
current.uncoloredDegree !=
vertexSaturations[current.vertex].uncoloredDegree) {
continue;
}

// Use the smallest unused color among neighbors.
int color = 0;
while (neighborColors[current.vertex].find(color) !=
neighborColors[current.vertex].end()) {
color++;
}

colors[current.vertex] = color;
updateNeighborSaturation(graph, current.vertex, color);
}

return colors;
}

private:
struct VertexInfo {
V vertex;
// The number of different colors used by neighbors, primary sort key.
int saturationDegree;
// The number of uncolored neighbors, secondary sort key.
int uncoloredDegree;

bool operator<(const VertexInfo& other) const {
if (saturationDegree != other.saturationDegree)
return saturationDegree < other.saturationDegree;
if (uncoloredDegree != other.uncoloredDegree)
return uncoloredDegree < other.uncoloredDegree;
// Visit smaller index vertices first in a tiebreak
return vertex > other.vertex;
}
};

void updateSaturationDegree(const UndirectedGraph<V>& graph,
const V& vertex) {
auto neighbors = graph.edgesIncidentTo(vertex);
int uncolored = 0;
for (const auto& neighbor : neighbors) {
if (colors.find(neighbor) == colors.end()) {
uncolored++;
}
}

VertexInfo info{vertex, static_cast<int>(neighborColors[vertex].size()),
uncolored};
vertexSaturations[vertex] = info;
queue.push(info);
}

void updateNeighborSaturation(const UndirectedGraph<V>& graph,
const V& vertex, int color) {
auto neighbors = graph.edgesIncidentTo(vertex);
for (const auto& neighbor : neighbors) {
if (colors.find(neighbor) == colors.end()) {
neighborColors[neighbor].insert(color);
updateSaturationDegree(graph, neighbor);
}
}
}

// The assigned color to each vertex.
std::unordered_map<V, int> colors;

// The set of colors assigned to the neighbors of each vertex,
// to avoid looping over them when determining the next color
// to use for the current vertex.
std::unordered_map<V, std::unordered_set<int>> neighborColors;

// A mapping of vertex to its saturation data.
std::unordered_map<V, VertexInfo> vertexSaturations;

// A priority queue to find the next vertex to color.
std::priority_queue<VertexInfo> queue;
};

} // namespace graph
} // namespace heir
} // namespace mlir
Expand Down
69 changes: 69 additions & 0 deletions lib/Utils/Graph/GraphTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,75 @@ TEST(LevelSortTest, MultiOutputGraphLevelSort) {
EXPECT_THAT(levelUnwrapped[5], UnorderedElementsAre(5, 6, 7, 8, 9));
}

TEST(GraphColorTest, SimpleGraph) {
// Example graph:
// / 2 \
// 0 - 1 - 3 - 4
// \ - - - /
UndirectedGraph<int> graph;
graph.addVertex(0);
graph.addVertex(1);
graph.addVertex(2);
graph.addVertex(3);
graph.addVertex(4);
EXPECT_TRUE(graph.addEdge(0, 1));
EXPECT_TRUE(graph.addEdge(1, 2));
EXPECT_TRUE(graph.addEdge(1, 3));
EXPECT_TRUE(graph.addEdge(2, 4));
EXPECT_TRUE(graph.addEdge(3, 4));

GreedyGraphColoring<int> greedy;
std::unordered_map<int, int> colors = greedy.color(graph);
// assertions in visitation order
EXPECT_EQ(colors[1], 0);
EXPECT_EQ(colors[4], 0);
EXPECT_EQ(colors[2], 1);
EXPECT_EQ(colors[3], 1);
EXPECT_EQ(colors[0], 1);
}

TEST(GraphColorTest, CompleteGraph) {
UndirectedGraph<int> graph;
graph.addVertex(0);
graph.addVertex(1);
graph.addVertex(2);
graph.addVertex(3);
graph.addVertex(4);
EXPECT_TRUE(graph.addEdge(0, 1));
EXPECT_TRUE(graph.addEdge(0, 2));
EXPECT_TRUE(graph.addEdge(0, 3));
EXPECT_TRUE(graph.addEdge(0, 4));
EXPECT_TRUE(graph.addEdge(1, 2));
EXPECT_TRUE(graph.addEdge(1, 3));
EXPECT_TRUE(graph.addEdge(1, 4));
EXPECT_TRUE(graph.addEdge(2, 3));
EXPECT_TRUE(graph.addEdge(2, 4));
EXPECT_TRUE(graph.addEdge(3, 4));

GreedyGraphColoring<int> greedy;
std::unordered_map<int, int> colors = greedy.color(graph);
EXPECT_EQ(colors[0], 0);
EXPECT_EQ(colors[1], 1);
EXPECT_EQ(colors[2], 2);
EXPECT_EQ(colors[3], 3);
EXPECT_EQ(colors[4], 4);
}

TEST(DSATURColorTest, StarGraph) {
// Center vertex connected to 4 leaves
UndirectedGraph<int> graph;
for (int i = 0; i < 5; i++) graph.addVertex(i);
for (int i = 1; i < 5; i++) EXPECT_TRUE(graph.addEdge(0, i));

GreedyGraphColoring<int> greedy;
auto colors = greedy.color(graph);
EXPECT_EQ(colors[0], 0); // Center colored first
for (int i = 1; i < 5; i++) {
EXPECT_EQ(colors[i], 1); // All leaves same color
EXPECT_NE(colors[0], colors[i]);
}
}

} // namespace
} // namespace graph
} // namespace heir
Expand Down

0 comments on commit ea2381e

Please sign in to comment.