Implementation
Last updated
Was this helpful?
Last updated
Was this helpful?
This implementation was forked from [original source — link missing] and reworked. Additionally, I have compressed the implementation into a single file, so I can grab it whenever I need to mock something up or use it conveniently on LeetCode.
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <math.h>
#include <algorithm>
#include <iomanip>
#include <iostream>
#include <limits>
#include <random>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
using namespace std;
// Blocks until a newline arrives on stdin, discarding everything typed before it.
inline void pause() {
    const streamsize everything = numeric_limits<streamsize>::max();
    cin.ignore(everything, '\n');
}
// Debug helper: prints a 2-D vector one row per line, each row prefixed with its index
// and each cell right-aligned to width 2, then waits for one character on stdin.
// Rows and cells are visited by const reference (previously every row and cell was
// copied), and the container may now be const.
template<typename T>
inline void print2d(const vector< vector<T> > &stuffs) {
    cout << endl;
    int count = 0;
    for (const auto &row : stuffs) {
        cout << count++ << ":";
        for (const auto &cell : row)
            cout << setw(2) << cell << " ";
        cout << endl;
    }
    cout << endl;
    // Pause so the printed grid can be inspected before execution continues.
    getchar();
}
// Debug helper: prints the elements of `stuffs` space-separated on one line, preceded
// and followed by a line break. Elements are visited by const reference (previously each
// was copied), and the container may now be const.
template<typename T>
inline void print(const vector<T> &stuffs) {
    cout << endl;
    for (const auto &item : stuffs) { cout << item << " "; }
    cout << endl;
}
// node.h
class Node {
public:
    // Next node ID handed out; the Edge(start, first, last) constructor post-increments it.
    // Defined and initialized outside the class.
    static int noOfNodes;
    // Suffix-link target node; -1 until one is assigned.
    int suffixNode = -1;

    Node() = default;
};
// edge.h
// Hash-key components identifying an edge: the node it starts at and the character its
// label begins with. Both default to -1, the "unset" value used by Edge, so a
// default-constructed Key is never left holding indeterminate values.
struct Key {
    int nodeID = -1;
    int asciiChar = -1;
};
class Edge {
public:
    // Node this edge leaves from; -1 marks the "no such edge" value findEdge() returns.
    int startNode;
    // Node this edge leads to.
    int endNode;
    // Inclusive range [startLabelIndex, endLabelIndex] of Input spelled by this edge.
    int startLabelIndex;
    int endLabelIndex;
    // IDs (hash-key components) of the edges recorded as this edge's children.
    vector<Key> children;
    // This edge's own hash-key components (callers set them to startNode and the first
    // label character).
    Key ID;
    // Hash-key components of the parent edge; {-1, -1} until one is recorded.
    Key parentID;

    void insert();
    void remove();
    static long returnHashKey(int node, int c);
    static Edge findEdge(int node, int c);
    void setID(int node, int c);
    void setParent(int node, int c);

    // "Not found" edge. Every field gets a defined value (previously only startNode did),
    // so code that reads the label range or parentID of a missing edge no longer reads
    // indeterminate memory.
    Edge() :
        startNode(-1),
        endNode(-1),
        startLabelIndex(-1),
        endLabelIndex(-1)
    {
        ID.nodeID = -1;
        ID.asciiChar = -1;
        parentID.nodeID = -1;
        parentID.asciiChar = -1;
    }

    // New edge from node `start` spelling Input[first..last]; takes a fresh node ID for
    // its end node.
    Edge(int start, int first, int last) :
        startNode(start),
        endNode(Node::noOfNodes++),
        startLabelIndex(first),
        endLabelIndex(last)
    {
        ID.asciiChar = -1;
        ID.nodeID = -1;
        parentID.nodeID = -1;
        parentID.asciiChar = -1;
    }
};
// suffixtree.h
// Result of getStartEdge(): the edge a pattern walk stopped on, plus bookkeeping.
struct Edge_info {
    int start_node;  // startNode of that edge (-1 when no edge was found)
    int end_node;    // endNode of that edge
    int cumLength;   // summed label lengths of the edges walked, through this one
    Edge e;          // copy of the edge itself
};
// Current position used while building the tree: a node plus the inclusive span
// [startIndex, endIndex] of Input that has been matched below that node.
class suffixTree {
public:
    int rootNode;
    int startIndex;
    int endIndex;

    suffixTree() :
        rootNode(0),
        startIndex(-1),
        endIndex(-1) {}
    suffixTree(int root, int start, int end) :
        rootNode(root),
        startIndex(start),
        endIndex(end) {}

    // True when the span is empty (startIndex > endIndex), i.e. the position is rootNode.
    // Query-only, so const: callable through `const suffixTree &`.
    bool endReal() const { return startIndex > endIndex; }
    // True when the span is non-empty; the exact negation of endReal().
    bool endImg() const { return endIndex >= startIndex; }

    void migrateToClosestParent();
};
// edge.cpp
// Packs (node, character) into the single key used by edgeHash.
// The old form `nodeID + ((long)c << 59)` overflowed a signed long for any c >= 16
// (undefined behavior); in practice only c's low 5 bits survived, so e.g. '$' and 'd'
// produced the same key and edges overwrote each other. Here the character is reduced to
// its 8-bit pattern (distinct for every char value, applied identically on insert and
// lookup) and placed below the node ID, which keeps the mapping one-to-one.
// NOTE(review): on platforms with 32-bit long this assumes nodeID * 256 fits, i.e. fewer
// than ~8M nodes.
long Edge::returnHashKey(int nodeID, int c) {
    return static_cast<long>(nodeID) * 256L + static_cast<long>(c & 0xFF);
}
void Edge::setID(int node, int c) {
this->ID.asciiChar = c;
this->ID.nodeID = node;
}
void Edge::setParent(int node, int c) {
this->parentID.asciiChar = c;
this->parentID.nodeID = node;
}
// suffixtree.cpp
// Text the tree is built over; the builders append a terminator before building.
string Input;
// Last index the builders treat as part of Input (they set it to Input.length() - 1).
int inputLength = 0;
// Next node ID to hand out; node 0 is the root.
int Node::noOfNodes = 1;
// Per-node data (suffix links); the builders allocate it with malloc.
Node *nodeArray = nullptr;
// Every edge, keyed by Edge::returnHashKey(startNode, first label character).
unordered_map<long, Edge> edgeHash;
void printAllEdges() {
int count = 0;
cout << "StartNode\tEndNode\tParentStart\tParentEnd" << endl;
for (auto it = edgeHash.begin(); it != edgeHash.end(); it++) {
Edge parent = Edge::findEdge(it->second.parentID.nodeID, it->second.parentID.asciiChar);
cout << it->second.startNode << "\t\t" << it->second.endNode
// << "\t\t" << nodeArray[it->second.endNode].suffixNode
// << "\t\t" << it->second.startLabelIndex
// << "\t\t" << it->second.endLabelIndex
<< "\t\t" << parent.startNode
<< "\t\t" << parent.endNode << " ";
count++;
int head;
if (inputLength > it->second.endLabelIndex)
head = it->second.endLabelIndex;
else
head = inputLength;
for (int i = it->second.startLabelIndex; i < head + 1; i++)
cout << Input[i];
cout << endl;
}
cout << "Total edges: " << count << endl;
for (auto it = edgeHash.begin(); it != edgeHash.end(); it++) {
Edge parent = Edge::findEdge(it->second.parentID.nodeID, it->second.parentID.asciiChar);
cout << "Parent is: " << parent.startNode << " " << parent.endNode << endl;
for (Key childID : parent.children) {
Edge child = Edge::findEdge(childID.nodeID, childID.asciiChar);
cout << child.startNode << "\t\t" << child.endNode << " ";
}
cout << endl;
}
}
// Stores a copy of this edge in edgeHash under (startNode, first label character),
// replacing any edge already stored under that key.
void Edge::insert() {
    edgeHash[returnHashKey(startNode, Input[startLabelIndex])] = *this;
}
// Erases whatever edgeHash holds under this edge's current (startNode, first label
// character) key.
void Edge::remove() {
    edgeHash.erase(returnHashKey(startNode, Input[startLabelIndex]));
}
// Returns a copy of the edge leaving `node` whose label starts with `asciiChar`, or a
// default-constructed Edge (startNode == -1) when there is none.
// Uses the iterator from find() directly instead of a second lookup via at().
Edge Edge::findEdge(int node, int asciiChar) {
    const auto found = edgeHash.find(returnHashKey(node, asciiChar));
    if (found != edgeHash.end())
        return found->second;
    return Edge();
}
void suffixTree::migrateToClosestParent() {
if (endReal()) {
}
else {
Edge e = Edge::findEdge(rootNode, Input[startIndex]);
if (e.startNode == -1) {
cout << rootNode << " " << startIndex << " " << Input[startIndex] << endl;
}
assert(e.startNode != -1);
int labelLength = e.endLabelIndex - e.startLabelIndex;
while (labelLength <= (endIndex - startIndex)) {
startIndex += labelLength + 1;
rootNode = e.endNode;
if (startIndex <= endIndex) {
e = Edge::findEdge(e.endNode, Input[startIndex]);
if (e.startNode == -1) {
cout << rootNode << " " << startIndex << " " << Input[startIndex] << endl;
}
assert(e.startNode != -1);
labelLength = e.endLabelIndex - e.startLabelIndex;
}
}
}
}
int breakEdge(const suffixTree &s, Edge &e) {
e.remove();
Edge *newEdge = new Edge(s.rootNode, e.startLabelIndex,
e.startLabelIndex + s.endIndex - s.startIndex);
nodeArray[newEdge->endNode].suffixNode = s.rootNode;
e.startLabelIndex += s.endIndex - s.startIndex + 1;
e.startNode = newEdge->endNode;
e.setID(e.startNode, Input[e.startLabelIndex]);
newEdge->children.push_back(e.ID);
newEdge->setID(newEdge->startNode, Input[newEdge->startLabelIndex]);
if (newEdge->startNode > 0) {
newEdge->setParent(nodeArray[newEdge->startNode].suffixNode, Input[s.startIndex - 1]);
}
newEdge->insert();
for (Key id : e.children) {
Edge child = Edge::findEdge(id.nodeID, id.asciiChar);
assert(child.ID.nodeID != -1);
child.setParent(e.ID.nodeID, e.ID.asciiChar);
child.insert();
}
e.setParent(newEdge->ID.nodeID, newEdge->ID.asciiChar);
e.insert();
return newEdge->endNode;
}
// Extends the tree with the character Input[lastIndex]. Repeatedly adds a leaf for the
// new character at the current position (splitting an edge first if the position is
// mid-edge) until the character is already present there, then advances endIndex.
//
// Each new leaf used to be allocated with `new` and never freed. insert() stores a copy
// in edgeHash, so a local object is sufficient.
void carryPhase(suffixTree &tree, int lastIndex) {
    int parentNode = tree.rootNode;
    int previousParentNode = -1;
    while (true) {
        Edge e;
        parentNode = tree.rootNode;
        if (tree.endReal()) {
            // Position is on a node: done if an edge there already starts with the char.
            e = Edge::findEdge(tree.rootNode, Input[lastIndex]);
            if (e.startNode != -1)
                break;
        }
        else {
            // Position is inside an edge: done if the next label char already matches;
            // otherwise split the edge and attach the leaf to the new internal node.
            e = Edge::findEdge(tree.rootNode, Input[tree.startIndex]);
            int diff = tree.endIndex - tree.startIndex;
            if (Input[e.startLabelIndex + diff + 1] == Input[lastIndex])
                break;
            parentNode = breakEdge(tree, e);
        }
        // Point the previous iteration's parent node's suffix link at this parent node.
        if (previousParentNode > 0) {
            nodeArray[previousParentNode].suffixNode = parentNode;
        }
        Edge newEdge(parentNode, lastIndex, inputLength);
        newEdge.setID(newEdge.startNode, Input[newEdge.startLabelIndex]);
        if (e.startNode == -1 && newEdge.startNode > 0) {
            // No split happened: look the parent edge up through the suffix link.
            Edge parent = Edge::findEdge(nodeArray[newEdge.startNode].suffixNode, Input[tree.startIndex - 1]);
            if (parent.startNode != -1) {
                newEdge.setParent(parent.ID.nodeID, parent.ID.asciiChar);
                parent.children.push_back(newEdge.ID);
                parent.insert();
            }
        }
        if (e.startNode != -1 && newEdge.startNode > 0) {
            // A split happened: breakEdge made e's parent the new upper edge.
            Edge parent = Edge::findEdge(e.parentID.nodeID, e.parentID.asciiChar);
            newEdge.setParent(parent.ID.nodeID, parent.ID.asciiChar);
            parent.children.push_back(newEdge.ID);
            parent.insert();
        }
        newEdge.insert();
        previousParentNode = parentNode;
        // Move to the next shorter suffix: drop a char from the root, or follow the link.
        if (tree.rootNode == 0)
            tree.startIndex++;
        else {
            tree.rootNode = nodeArray[tree.rootNode].suffixNode;
        }
        tree.migrateToClosestParent();
    }
    if (previousParentNode > 0)
        nodeArray[previousParentNode].suffixNode = parentNode;
    tree.endIndex++;
    tree.migrateToClosestParent();
}
// Returns true when `pattern` labels a path from the root, i.e. occurs in Input.
// Walks the tree character by character, following the child edge for the next
// unmatched pattern character after each edge is fully consumed.
// Fixes: the old `assert(iter = ...)` assigned instead of comparing (the walk now needs
// no such check), and the pattern is taken by const reference instead of copied.
bool search(const string &pattern) {
    const int len = pattern.length();
    Edge e = Edge::findEdge(0, pattern[0]);
    int matched = 0;  // pattern characters matched so far
    while (e.startNode != -1) {
        for (int k = e.startLabelIndex; k <= e.endLabelIndex; ++k) {
            if (Input[k] != pattern[matched])
                return false;
            if (++matched >= len)
                return true;
        }
        e = Edge::findEdge(e.endNode, pattern[matched]);
    }
    return false;
}
// Walks `pattern` down from the root. Returns the index in Input at which the pattern's
// last character was matched, or -1 if the pattern does not label a path from the root.
// Fixes: the old `assert(iter = ...)` assigned instead of comparing (the walk now needs
// no such check), and the pattern is taken by const reference instead of copied.
int getIndex(const string &pattern) {
    const int len = pattern.length();
    Edge e = Edge::findEdge(0, pattern[0]);
    int matched = 0;  // pattern characters matched so far
    while (e.startNode != -1) {
        for (int k = e.startLabelIndex; k <= e.endLabelIndex; ++k) {
            if (Input[k] != pattern[matched])
                return -1;
            if (++matched >= len)
                return k;
        }
        e = Edge::findEdge(e.endNode, pattern[matched]);
    }
    return -1;
}
// Releases the node array and clears all global tree state so a new tree can be built.
// - nodeArray is allocated with malloc by the builders, so it must be released with
//   free(); the old `delete` on malloc'd memory was undefined behavior, and the pointer
//   was left dangling.
// - Node 0 is the root, and Node::noOfNodes starts at 1, so new nodes are numbered from
//   1 again (the old reset to 0 let the first new node reuse the root's ID).
// - State is cleared unconditionally; previously nothing was cleared when nodeArray was
//   null, so e.g. a caller appending to Input could see stale text.
void reset() {
    free(nodeArray);
    nodeArray = nullptr;
    edgeHash.clear();
    Node::noOfNodes = 1;
    Input.clear();
}
// Walks `pattern` down from the root and returns the edge on which its last character
// lies. cumLength is the summed label length of every edge walked, up to and including
// that edge. If the pattern is empty or does not occur, returns start_node == end_node ==
// -1, cumLength == 0, and a default (not-found) Edge in `e`.
// Fixes: a character mismatch used to spin forever (the inner loop had no exit on
// mismatch). Label fields of not-found edges are no longer read, and the
// `assert(iter = ...)` assignment is gone.
Edge_info getStartEdge(const string &pattern) {
    Edge_info start;  // start.e is default-constructed: startNode == -1, no children
    start.start_node = -1;
    start.end_node = -1;
    start.cumLength = 0;
    const int len = pattern.length();
    if (len == 0)
        return start;
    Edge e = Edge::findEdge(0, pattern[0]);
    int matched = 0;    // pattern characters matched so far
    int cumLength = 0;  // label length walked, including the current edge
    while (e.startNode != -1) {
        cumLength += abs(e.startLabelIndex - e.endLabelIndex) + 1;
        for (int k = e.startLabelIndex; k <= e.endLabelIndex; ++k) {
            if (Input[k] != pattern[matched])
                return start;  // mismatch: pattern does not occur
            if (++matched >= len) {
                start.start_node = e.startNode;
                start.end_node = e.endNode;
                start.cumLength = cumLength;
                start.e = e;
                return start;
            }
        }
        e = Edge::findEdge(e.endNode, pattern[matched]);
    }
    return start;  // ran out of edges before the pattern ended
}
Sanity Testing for two applications: (1) Exact Matching (2) Smallest Lexicographical Rotation
// exactMatching.cpp
void DFS_subtree(Edge curEdge, int cumLength, std::vector<int> &occurrence, int textLength) {
for (Key childID : curEdge.children) {
Edge child = Edge::findEdge(childID.nodeID, childID.asciiChar);
int newCumLength =
cumLength +
abs(child.startLabelIndex - child.endLabelIndex) + 1;
DFS_subtree(child, newCumLength, occurrence, textLength);
if (child.children.empty()) {
occurrence.push_back(((textLength - 1) - newCumLength) + 1);
}
}
}
vector<int> exactStringMatching(const string text, const string pattern) {
// reset data structures
reset();
Input = text + '$';
inputLength = Input.length() - 1;
nodeArray = reinterpret_cast<Node *> (malloc(2 * inputLength * sizeof(Node)));
suffixTree tree(0, 0, -1);
for (int i = 0; i <= inputLength; i++) {
carryPhase(tree, i);
}
std::vector<int> occurrences;
// found - iterate to first occurrence
Edge_info start = getStartEdge(pattern);
Edge parent = Edge::findEdge(start.e.parentID.nodeID, start.e.parentID.asciiChar);
// edgecase: occurance is immediately at leaf branch
if (start.e.children.empty()) {
if (start.start_node == 0) {
occurrences.push_back(start.e.startLabelIndex);
}
else if (start.e.children.empty() && start.start_node > 0) {
int found = getIndex(pattern) - pattern.length() + 1;
if (found != -1)
occurrences.push_back(found);
}
}
// check if otherwise exists: then DFS through subtree
else if (start.start_node != -1) {
DFS_subtree(start.e, start.cumLength, occurrences, Input.length());
}
return occurrences;
}
// lexicoSmallest.cpp
// Starting from `minChild` (whose path length is cumStart), repeatedly follows the child
// edge with the smallest first character until at least N characters have been walked.
// Returns the inclusive [first, second] index range in Input of those first N characters.
// Fixes:
// - minNodeId was read uninitialized when no child qualified; the walk now stops there
//   instead.
// - The stale '$' exclusion (a leftover from exactStringMatching) is dropped. The
//   terminator appended by smallestLexicographicRotation is numeric_limits<char>::max(),
//   which the strict '<' below already never selects, so a real '$' in the text is no
//   longer skipped.
pair<int, int> linearTraversalLex(Edge minChild, const size_t N, int cumStart) {
    std::pair<int, int> minLexIndex;
    const int target = static_cast<int>(N);
    int endIndex = cumStart - 1;
    while (cumStart < target) {
        // Find the child edge with the smallest first character.
        char minChar = std::numeric_limits<char>::max();
        int minNodeId = -1;
        for (const Key &childID : minChild.children) {
            if (childID.asciiChar < minChar) {
                minChar = childID.asciiChar;
                minNodeId = childID.nodeID;
            }
        }
        if (minNodeId == -1)
            break;  // no usable child: nothing further to follow
        minChild = Edge::findEdge(minNodeId, minChar);
        cumStart += abs(minChild.startLabelIndex - minChild.endLabelIndex) + 1;
        endIndex = minChild.endLabelIndex;
    }
    // cumStart may overshoot N inside the last edge; back off by the overshoot to land on
    // the index of the N-th character.
    const int overflow = abs(cumStart - target);
    minLexIndex.second = abs(endIndex - overflow);
    minLexIndex.first = minLexIndex.second - target + 1;
    return minLexIndex;
}
string smallestLexicographicRotation(const string &s) {
// reset data structures
reset();
Input = s + s + std::numeric_limits<char>::max();
inputLength = Input.length() - 1;
nodeArray = reinterpret_cast<Node *>(malloc(2 * inputLength * (sizeof(Node))));
suffixTree tree(0, 0, -1);
for (int i = 0; i <= inputLength; i++) {
carryPhase(tree, i);
}
char min_char = std::numeric_limits<char>::max();
for (char c : s) {
if (c < min_char)
min_char = c;
}
Edge root = Edge::findEdge(0, min_char);
int cumStart = abs(root.startLabelIndex - root.endLabelIndex) + 1;
std::pair<int, int> startEnd = linearTraversalLex(root, s.size(), cumStart);
return Input.substr(startEnd.first, startEnd.second - startEnd.first + 1);
}
// sanity checking - exactMatching
// Number of random cases exactMatchingSanityCheck() runs.
constexpr int NUM_TEST_CASES_EXT1 = 200;
// Length of each random text in that check.
constexpr int RANDOM_STRING_SIZE_EXT1 = 20;
// Reference matcher: returns every index at which `pattern` occurs in `text`, ascending
// (overlapping matches included; an empty pattern yields no matches).
// Fixes:
// - Each candidate substring was built twice (once into an unused `temp`).
// - text.size() - pLen underflowed size_t when the pattern was longer than the text.
// Comparison now happens in place via compare(), with no temporary strings.
std::vector<int> bruteForceOccurances(const string &text, const string &pattern) {
    std::vector<int> occurances;
    if (pattern.empty() || pattern.size() > text.size())
        return occurances;
    const size_t pLen = pattern.size();
    const size_t lastStart = text.size() - pLen;
    for (size_t i = 0; i <= lastStart; ++i) {
        if (text.compare(i, pLen, pattern) == 0)
            occurances.push_back(static_cast<int>(i));
    }
    return occurances;
}
// Randomized self-check: for NUM_TEST_CASES_EXT1 random lowercase texts, a pattern cut
// from the text is matched with both the brute-force and the suffix-tree matcher, and the
// sorted results must agree. The engine uses its default seed, so runs are reproducible.
void exactMatchingSanityCheck() {
    cout << "exactMatchingSanityCheck Starting" << endl;
    std::default_random_engine seed;
    std::uniform_int_distribution<int> alphabetGenerator(0, 25);
    inputLength = RANDOM_STRING_SIZE_EXT1;
    Input.clear();
    for (int caseNo = 0; caseNo < NUM_TEST_CASES_EXT1; caseNo++) {
        // Random text.
        string test;
        for (int k = 0; k < RANDOM_STRING_SIZE_EXT1; k++)
            test.push_back('a' + alphabetGenerator(seed));
        // Pattern: a random (possibly empty) substring of the text.
        std::uniform_int_distribution<int> startGenerator(0, RANDOM_STRING_SIZE_EXT1 - 1);
        const int start = startGenerator(seed);
        std::uniform_int_distribution<int> lenGenerator(0, RANDOM_STRING_SIZE_EXT1 - start - 1);
        const int end = lenGenerator(seed);
        // The pattern must not be longer than the text.
        assert((start + end) < test.length());
        const std::string P = test.substr(start, end);
        // Compare both matchers.
        std::vector<int> occurancesBF = bruteForceOccurances(test, P);
        std::vector<int> occurancesUK = exactStringMatching(test, P);
        sort(occurancesUK.begin(), occurancesUK.end());
        assert(occurancesUK.size() == occurancesBF.size());
        for (size_t j = 0; j < occurancesBF.size(); j++)
            assert(occurancesBF.at(j) == occurancesUK.at(j));
        Input.clear();
        cout << caseNo << " ";
    }
    cout << endl;
    cout << "exactMatchingSanityCheck Complete" << endl << endl;
}
// sanity checking - lexicoSmallest
// Number of random cases lexicoSmallestSanityTest() runs.
constexpr int NUM_TEST_CASES_EXT2 = 1000;
// Length of each random string in that test.
constexpr int RANDOM_STRING_SIZE_EXT2 = 20;
// Reference answer: tries every rotation of `text` (as a window of text + text) and
// keeps the lexicographically smallest one.
string brute_force_min_rotation(const string &text) {
    const std::string doubled = text + text;
    const size_t width = text.length();
    std::string best = doubled.substr(0, width);
    for (size_t shift = 1; shift <= width; ++shift) {
        const std::string candidate = doubled.substr(shift, width);
        if (candidate < best)
            best = candidate;
    }
    return best;
}
// Randomized self-check: for NUM_TEST_CASES_EXT2 random lowercase strings, the suffix-tree
// minimal rotation must equal the brute-force one. The engine uses its default seed, so
// runs are reproducible.
void lexicoSmallestSanityTest() {
    cout << "lexicoSmallestSanityTest Starting" << endl;
    std::default_random_engine seed;
    std::uniform_int_distribution<int> alphabetGenerator(0, 25);
    inputLength = RANDOM_STRING_SIZE_EXT2;
    string testStr;
    for (int caseNo = 0; caseNo < NUM_TEST_CASES_EXT2; caseNo++) {
        testStr.clear();
        for (int k = 0; k < RANDOM_STRING_SIZE_EXT2; k++)
            testStr.push_back('a' + alphabetGenerator(seed));
        const std::string bf_result = brute_force_min_rotation(testStr);
        const std::string uk_result = smallestLexicographicRotation(testStr);
        assert(bf_result == uk_result);
        cout << caseNo << " ";
    }
    cout << endl;
    cout << "lexicoSmallestSanityTest Complete" << endl << endl;
}
int main(int argc, char** argv)
{
    // Run both randomized checks against their brute-force references.
    exactMatchingSanityCheck();
    lexicoSmallestSanityTest();
    return 0;
}
Generalized Suffix Tree: return the occurrences of a pattern across a set of strings, each followed by its own delimiter ('$' plus the string's index).
vector<int> generalizedExactStringMatching(const vector<string> texts, const string pattern) {
reset();
const char unique_delim = '$';
for (int i = 0; i < texts.size(); i++)
Input += texts.at(i) + (unique_delim + to_string(i));
inputLength = Input.length() - 1;
nodeArray = reinterpret_cast<Node *> (malloc(2 * inputLength * sizeof(Node)));
suffixTree tree(0, 0, -1);
for (int i = 0; i <= inputLength; i++) {
carryPhase(tree, i);
}
std::vector<int> occurrences;
// found - iterate to first occurrence
Edge_info start = getStartEdge(pattern);
Edge parent = Edge::findEdge(start.e.parentID.nodeID, start.e.parentID.asciiChar);
// edgecase: occurance is immediately at leaf branch
if (start.e.children.empty()) {
if (start.start_node == 0) {
occurrences.push_back(start.e.startLabelIndex);
}
else if (start.e.children.empty() && start.start_node > 0) {
int found = getIndex(pattern) - pattern.length() + 1;
if (found != -1)
occurrences.push_back(found);
}
}
// check if otherwise exists: then DFS through subtree
else if (start.start_node != -1) {
DFS_subtree(start.e, start.cumLength, occurrences, Input.length());
}
return occurrences;
}
int main(int argc, char** argv)
{
    // Concatenated input and its character indices:
    // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
    // b a n a n a $ 0 b a t m a n $ 1
    vector<int> hits = generalizedExactStringMatching({"banana", "batman"}, "an");
    print(hits); // 1 3 12
    return 0;
}