Implementation
// One element of a disjoint-set forest. A node whose `parent` points at itself is the root
// (representative) of its set. `depth` is the rank value DisjointSet compares and bumps in
// union_ to decide which root goes under which.
struct Node {
    // Starts as a singleton root: parent is the node itself and depth (rank) is 0.
    // Initializers are listed in declaration order (depth, id, parent) to match the actual
    // initialization order.
    Node(const int id) : depth(0), id(id), parent(this) {}
    int depth;
    const int id;
    Node *parent;
};
class DisjointSet {
public:
DisjointSet(const int ids) : num_ids(ids) {
init_find.reserve(ids);
for (int i = 1; i <= ids; i++)
make_set(i);
}
int find(const int x) {
if (init_find.find(x) == init_find.end()) return -1;
else return _find(x);
}
void union_(const int x, const int y) {
prev_op = "union_(" + to_string(x) + ", " + to_string(y) + ")";
int x_parent = find(x);
int y_parent = find(y);
if (init_find[y_parent]->depth <= init_find[x_parent]->depth)
init_find[y_parent]->parent = init_find[x_parent]->parent;
else if (init_find[y_parent]->depth > init_find[x_parent]->depth)
init_find[x_parent]->parent = init_find[y_parent]->parent;
// doesnt matter which, but assuming x.id < y.id I arbitrarily chose smaller one
if (init_find[y_parent]->depth == init_find[x_parent]->depth)
init_find[x_parent]->depth++;
}
void print_paths_to_root() {
cout << prev_op << endl;
for (int i = 1; i <= num_ids; i++) {
cout << setw(3) << i << ": ";
Node *iter = init_find[i];
while (iter->id != iter->parent->id) {
cout << iter->parent->id << ", ";
iter = iter->parent;
}
cout << endl;
}
cout << endl;
}
unordered_map<int, vector<int>> get_disjoint_sets() {
unordered_map<int, vector<int>> roots;
// simplify things - ensures no duplicates in the vector
for (int i = 1; i < num_ids; i++) find(i);
for (int i = 1; i <= num_ids; i++) {
vector<int> to_root;
Node *iter = init_find[i];
to_root.push_back(iter->id);
while (iter->id != iter->parent->id) {
to_root.push_back(iter->parent->id);
iter = iter->parent;
}
int root = to_root.back();
to_root.pop_back();
if (roots.find(root) != roots.end()) {
for (int i : to_root) {
roots[root].push_back(i);
}
}
else {
roots.insert({root, to_root});
}
}
return roots;
}
static void print_sets(const unordered_map<int, vector<int>> &elms) {
int set_num = 1;
for (auto i : elms) {
cout << "set " << set_num++ << ": ";
cout << i.first;
for (int j : i.second) {
cout << ", " << j;
}
cout << endl;
}
cout << endl;
}
private:
string prev_op;
const size_t num_ids;
void make_set(const int x) {
if (init_find.find(x) != init_find.end()) return;
else init_find.insert({ x, new Node(x) });
}
int _find(const int x) {
Node *iter = init_find[x]->parent;
if (iter->id == iter->parent->id) {
return iter->id;
}
else {
// path compression
while (iter->id != iter->parent->id)
iter = iter->parent;
init_find[x]->parent = init_find[iter->id];
return iter->id;
}
}
unordered_map<int, Node*> init_find;
};
int main() {
DisjointSet ds(7);
ds.print_paths_to_root(); ds.print_sets(ds.get_disjoint_sets()); //pause();
ds.union_(1, 2); ds.print_paths_to_root(); ds.print_sets(ds.get_disjoint_sets()); //pause();
ds.union_(2, 3); ds.print_paths_to_root(); ds.print_sets(ds.get_disjoint_sets()); //pause();
ds.union_(4, 5); ds.print_paths_to_root(); ds.print_sets(ds.get_disjoint_sets()); //pause();
ds.union_(6, 7); ds.print_paths_to_root(); ds.print_sets(ds.get_disjoint_sets()); //pause();
ds.union_(5, 6); ds.print_paths_to_root(); ds.print_sets(ds.get_disjoint_sets()); //pause();
ds.union_(3, 7); ds.print_paths_to_root(); ds.print_sets(ds.get_disjoint_sets()); //pause();
}Discussion
Last updated