Implementation

This implementation was forked and reworked from https://github.com/atuljangra/Ukkonen-SuffixTree. Additionally, I have compressed the implementation down to a single file, so that I can grab it whenever I need to mock something up or use it conveniently for LeetCode.

#include <iostream>
#include <string>
#include <vector>
#include <unordered_map>
#include <algorithm>
#include <limits>
#include <math.h>
#include <random>
#include <unordered_set>
#include <iomanip>
#include <cassert>


using namespace std;

inline void pause() { cin.ignore(numeric_limits<streamsize>::max(), '\n'); };

template<typename T>
inline void print2d(vector< vector<T> > &stuffs) {
    cout << endl;
    int count = 0;
    for (auto i : stuffs) {
        cout << count++ << ":";
        for (auto j : i)
            cout << setw(2) << j << " ";
        cout << endl;
    }
    cout << endl;
    getchar();
}

template<typename T>
inline void print(vector<T> &stuffs) {
    cout << endl;
    for (auto i : stuffs) { cout << i << " "; }
    cout << endl;
}

// node.h

class Node {
public:
    static int noOfNodes;
    int suffixNode;
    Node() :
        suffixNode(-1) {};

    ~Node() {
    }
};

// edge.h

struct Key {
    int nodeID;
    int asciiChar;
};

class Edge {
public:
    int startNode;
    int endNode;
    int startLabelIndex;
    int endLabelIndex;

    vector<Key> children;
    Key ID;
    Key parentID;

    void insert();
    void remove();
    static long returnHashKey(int node, int c);
    static Edge findEdge(int node, int c);
    void setID(int node, int c);
    void setParent(int node, int c);

    Edge() : startNode(-1) {};
    Edge(int start, int first, int last) :
        startNode(start),
        endNode(Node::noOfNodes++),
        startLabelIndex(first),
        endLabelIndex(last)
    {
        ID.asciiChar = -1;
        ID.nodeID = -1;
        parentID.nodeID = -1;
        parentID.asciiChar = -1;
    };

    ~Edge() {}

};

// suffixtree.h

struct Edge_info {
    int start_node, end_node, cumLength;
    Edge e;
};

class suffixTree {
public:
    int rootNode;
    int startIndex;
    int endIndex;

    suffixTree() :
        rootNode(0),
        startIndex(-1),
        endIndex(-1) {};
    suffixTree(int root, int start, int end) :
        rootNode(root),
        startIndex(start),
        endIndex(end) {};
    bool endReal() { return startIndex > endIndex; }
    bool endImg() { return endIndex >= startIndex; }
    void migrateToClosestParent();
};

// edge.cpp
long Edge::returnHashKey(int nodeID, int c) {
    return (nodeID + (((long)c) << 59));
}

void Edge::setID(int node, int c) {
    this->ID.asciiChar = c;
    this->ID.nodeID = node;
}
void Edge::setParent(int node, int c) {
    this->parentID.asciiChar = c;
    this->parentID.nodeID = node;
}

// suffixtree.cpp
string Input;
int inputLength;
int Node::noOfNodes = 1;
Node *nodeArray;
unordered_map<long, Edge> edgeHash;

void printAllEdges() {
    int count = 0;
    cout << "StartNode\tEndNode\tParentStart\tParentEnd" << endl;
    for (auto it = edgeHash.begin(); it != edgeHash.end(); it++) {
        Edge parent = Edge::findEdge(it->second.parentID.nodeID, it->second.parentID.asciiChar);
        cout << it->second.startNode << "\t\t" << it->second.endNode
            //             << "\t\t" << nodeArray[it->second.endNode].suffixNode
            //             << "\t\t" << it->second.startLabelIndex
            //             << "\t\t" << it->second.endLabelIndex
            << "\t\t" << parent.startNode
            << "\t\t" << parent.endNode << " ";
        count++;
        int head;
        if (inputLength > it->second.endLabelIndex)
            head = it->second.endLabelIndex;
        else
            head = inputLength;
        for (int i = it->second.startLabelIndex; i < head + 1; i++)
            cout << Input[i];
        cout << endl;
    }
    cout << "Total edges: " << count << endl;
    for (auto it = edgeHash.begin(); it != edgeHash.end(); it++) {
        Edge parent = Edge::findEdge(it->second.parentID.nodeID, it->second.parentID.asciiChar);
        cout << "Parent is: " << parent.startNode << " " << parent.endNode << endl;
        for (Key childID : parent.children) {
            Edge child = Edge::findEdge(childID.nodeID, childID.asciiChar);
            cout << child.startNode << "\t\t" << child.endNode << " ";
        }
        cout << endl;
    }

}

void Edge::insert() {
    long key = returnHashKey(startNode, Input[startLabelIndex]);
    edgeHash[key] = *this;
}

void Edge::remove() {
    long key = returnHashKey(startNode, Input[startLabelIndex]);
    edgeHash.erase(key);
}

Edge Edge::findEdge(int node, int asciiChar) {
    long key = returnHashKey(node, asciiChar);
    unordered_map<long, Edge>::const_iterator search = edgeHash.find(key);

    if (search != edgeHash.end()) {
        return edgeHash.at(key);
    }

    return Edge();
}

void suffixTree::migrateToClosestParent() {

    if (endReal()) {
    }
    else {
        Edge e = Edge::findEdge(rootNode, Input[startIndex]);

        if (e.startNode == -1) {
            cout << rootNode << " " << startIndex << " " << Input[startIndex] << endl;
        }
        assert(e.startNode != -1);
        int labelLength = e.endLabelIndex - e.startLabelIndex;

        while (labelLength <= (endIndex - startIndex)) {
            startIndex += labelLength + 1;
            rootNode = e.endNode;
            if (startIndex <= endIndex) {
                e = Edge::findEdge(e.endNode, Input[startIndex]);
                if (e.startNode == -1) {
                    cout << rootNode << " " << startIndex << " " << Input[startIndex] << endl;
                }
                assert(e.startNode != -1);
                labelLength = e.endLabelIndex - e.startLabelIndex;
            }
        }

    }
}

int breakEdge(const suffixTree &s, Edge &e) {

    e.remove();

    Edge *newEdge = new Edge(s.rootNode, e.startLabelIndex,
        e.startLabelIndex + s.endIndex - s.startIndex);

    nodeArray[newEdge->endNode].suffixNode = s.rootNode;
    e.startLabelIndex += s.endIndex - s.startIndex + 1;
    e.startNode = newEdge->endNode;
    e.setID(e.startNode, Input[e.startLabelIndex]);

    newEdge->children.push_back(e.ID);
    newEdge->setID(newEdge->startNode, Input[newEdge->startLabelIndex]);
    if (newEdge->startNode > 0) {
        newEdge->setParent(nodeArray[newEdge->startNode].suffixNode, Input[s.startIndex - 1]);
    }
    newEdge->insert();


    for (Key id : e.children) {
        Edge child = Edge::findEdge(id.nodeID, id.asciiChar);
        assert(child.ID.nodeID != -1);
        child.setParent(e.ID.nodeID, e.ID.asciiChar);
        child.insert();
    }

    e.setParent(newEdge->ID.nodeID, newEdge->ID.asciiChar);
    e.insert();

    return newEdge->endNode;
}

void carryPhase(suffixTree &tree, int lastIndex) {

    int parentNode;
    int previousParentNode = -1;
    while (true) {
        Edge e;
        parentNode = tree.rootNode;

        if (tree.endReal()) {
            e = Edge::findEdge(tree.rootNode, Input[lastIndex]);
            if (e.startNode != -1)
                break;
        }
        else {
            e = Edge::findEdge(tree.rootNode, Input[tree.startIndex]);
            int diff = tree.endIndex - tree.startIndex;
            if (Input[e.startLabelIndex + diff + 1] == Input[lastIndex])
                break;
            parentNode = breakEdge(tree, e);
        }

        if (previousParentNode > 0) {
            nodeArray[previousParentNode].suffixNode = parentNode;
        }

        Edge *newEdge = new Edge(parentNode, lastIndex, inputLength);
        newEdge->setID(newEdge->startNode, Input[newEdge->startLabelIndex]);


        if (e.startNode == -1 && newEdge->startNode > 0) {
            Edge parent = Edge::findEdge(nodeArray[newEdge->startNode].suffixNode, Input[tree.startIndex - 1]);
            if (parent.startNode != -1) {
                newEdge->setParent(parent.ID.nodeID, parent.ID.asciiChar);
                parent.children.push_back(newEdge->ID);
                parent.insert();
            }
        }

        if (e.startNode != -1 && newEdge->startNode > 0) {
            Edge parent = Edge::findEdge(e.parentID.nodeID, e.parentID.asciiChar);
            newEdge->setParent(parent.ID.nodeID, parent.ID.asciiChar);
            parent.children.push_back(newEdge->ID);
            parent.insert();
        }

        newEdge->insert();
        previousParentNode = parentNode;
        if (tree.rootNode == 0)
            tree.startIndex++;
        else {
            tree.rootNode = nodeArray[tree.rootNode].suffixNode;
        }

        tree.migrateToClosestParent();
    }

    if (previousParentNode > 0)
        nodeArray[previousParentNode].suffixNode = parentNode;
    tree.endIndex++;
    tree.migrateToClosestParent();
}


bool search(string pattern) {
    int len = pattern.length();
    Edge e = Edge::findEdge(0, pattern[0]);
    int iter = 0;
    int i = -1;
    if (e.startNode != -1) {
        while (i < len) {
            iter = 0;
            while (e.endLabelIndex >= e.startLabelIndex + iter) {
                if (Input[e.startLabelIndex + iter] == pattern[i + iter + 1]) {
                    iter++;
                    if (i + iter + 1 >= len) {
                        return true;
                    }
                }
                else {
                    return false;
                }
            }
            assert(iter = (e.endLabelIndex - e.startLabelIndex + 1));
            e = Edge::findEdge(e.endNode, pattern[i + iter + 1]);
            if (e.startNode == -1) {
                return false;
            }
            i += (iter);
        }
    }
    return false;
}

int getIndex(string pattern) {

    int len = pattern.length();
    Edge e = Edge::findEdge(0, pattern[0]);
    int iter = 0;
    int i = -1;
    if (e.startNode != -1) {
        while (i < len) {
            iter = 0;
            while (e.endLabelIndex >= e.startLabelIndex + iter) {
                if (Input[e.startLabelIndex + iter] == pattern[i + iter + 1]) {
                    iter++;
                    if (i + iter + 1 >= len) {
                        return e.startLabelIndex + iter - 1;
                    }
                }
                else {
                    return -1;
                }
            }
            assert(iter = (e.endLabelIndex - e.startLabelIndex + 1));
            e = Edge::findEdge(e.endNode, pattern[i + iter + 1]);
            if (e.startNode == -1) {
                return -1;
            }
            i += (iter);
        }
    }
    return -1;
}

// reset the data structure for next applications
void reset() {
    if (nodeArray) {
        delete nodeArray;
        edgeHash.clear();
        Node::noOfNodes = 0;
        Input.clear();
    }
}

// identifies the first starting edge to the pattern of the tree
Edge_info getStartEdge(const string &pattern) {
    int len = pattern.length();
    Edge e = Edge::findEdge(0, pattern[0]);
    Edge_info start;
    int iter = 0;
    int i = -1;
    int cumLength = 0;
    cumLength += abs(e.startLabelIndex - e.endLabelIndex) + 1;

    if (e.startNode != -1) {
        while (i < len) {
            iter = 0;
            while (e.endLabelIndex >= e.startLabelIndex + iter) {
                if (Input[e.startLabelIndex + iter] == pattern[i + iter + 1]) {
                    iter++;
                    if (i + iter + 1 >= len) {
                        start.start_node = e.startNode;
                        start.end_node = e.endNode;
                        start.cumLength = cumLength;
                        start.e = e;
                        return start;
                    }
                }
            }
            assert(iter = (e.endLabelIndex - e.startLabelIndex + 1));
            e = Edge::findEdge(e.endNode, pattern[i + iter + 1]);
            cumLength += abs(e.startLabelIndex - e.endLabelIndex) + 1;
            if (e.startNode == -1) {
                start.start_node = e.startNode;
                start.end_node = e.endNode;
                start.cumLength = cumLength;
                start.e = e;
                return start;
            }
            i += (iter);
        }
    }
    start.start_node = e.startNode;
    start.end_node = e.endNode;
    start.cumLength = cumLength;
    start.e = e;
    return start;
}

Sanity Testing for two applications: (1) Exact Matching (2) Smallest Lexicographical Rotation

Generalized Suffix Tree: Return the occurrences found within a set of delimited strings.

Last updated

Was this helpful?