Data Structures and Algorithms B
  • Introduction
  • Stable Marriage
  • Huffman Codes
    • ABL
    • Implementation
  • Graph Algorithms
    • Connect Component
    • Bipartiteness
    • Strongly Connected Components
      • Implementation
    • Topological Sort
      • Implementation 1
      • Implementation 2
    • Dijkstra’s Shortest Path
    • Minimum Spanning Tree
      • Prims
        • Implementation 1
      • Kruskal's
        • Implementation 1
      • Traveling Salesman Problem
    • k-Clustering
    • Dynamic Programming with Trees
      • Implementation 1
      • Implementation 2
    • Disjoint Sets
      • Implementation
    • Eulerian Cycle
      • Implementation
    • Hamiltonian Path
      • Implementation
  • Divide and Conquer
    • Merge Sort
    • Integer Multiplication
    • Closest Pair
    • Master Method
  • Dynamic Programming
    • Interval Scheduling
    • Knapsack Problem
  • Network Flow
    • Maximum Network Flow
    • Improvements
    • Image Segmentation
  • String Algorithms
    • Z Algorithm
    • Boyer-Moore
    • Knuth-Morris-Pratt
    • Suffix Trees
      • Naive Implementation
      • Ukkonens Algorithm
        • Implementation
      • Applications
        • Longest Common Substring
        • Longest Palindromic Substring
        • Longest Repeated Substring
  • Randomized Algorithms
    • Bloom Filters
      • Implementation
Powered by GitBook
On this page

Was this helpful?

  1. String Algorithms
  2. Suffix Trees
  3. Ukkonens Algorithm

Implementation

Previous: Ukkonens Algorithm | Next: Applications

Last updated 5 years ago

Was this helpful?

This implementation was forked and reworked from https://github.com/atuljangra/Ukkonen-SuffixTree. Additionally, I have compressed the implementation down to a single file. This way I can grab it whenever I need to mock something up or use it conveniently for LeetCode.

#include <math.h>

#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <limits>
#include <random>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>


using namespace std;

// Discards pending stdin input up to and including the next newline,
// which in an interactive session means "wait until Enter is pressed".
inline void pause()
{
    const streamsize noLimit = numeric_limits<streamsize>::max();
    cin.ignore(noLimit, '\n');
}

// Prints a 2-D vector to stdout: a blank line, then one line per row of the
// form "<rowIndex>:" followed by each element right-aligned in a 2-character
// field plus a space, then a trailing blank line. Afterwards it blocks on
// getchar() until a character is available on stdin.
//
// The vector is taken by const reference and rows/elements are visited by
// reference, so nothing is copied and const or temporary vectors are accepted.
template<typename T>
inline void print2d(const std::vector< std::vector<T> > &stuffs) {
    std::cout << std::endl;
    int rowIndex = 0;
    for (const auto &row : stuffs) {
        std::cout << rowIndex++ << ":";
        for (const auto &cell : row)
            std::cout << std::setw(2) << cell << " ";
        std::cout << std::endl;
    }
    std::cout << std::endl;
    getchar();
}

// Prints the elements of a vector on one line, each followed by a single
// space, with a newline before and after.
//
// Takes the vector by const reference and visits elements by reference so
// that no element is copied and const/temporary vectors can be printed.
template<typename T>
inline void print(const std::vector<T> &stuffs) {
    std::cout << std::endl;
    for (const auto &item : stuffs) { std::cout << item << " "; }
    std::cout << std::endl;
}

// node.h

class Node {
public:
    // Running id counter; the three-argument Edge constructor post-increments
    // it to obtain the id of the node at the end of each new edge.
    static int noOfNodes;

    // Id of the node this node's suffix link points to; -1 until assigned.
    int suffixNode = -1;

    Node() {}
    ~Node() {}
};

// edge.h

// Identifies an edge by (start node id, first character of its label) — the
// same pair that Edge::returnHashKey turns into the edgeHash key.
//
// Both fields default to -1, the "unset" sentinel used elsewhere in this
// file, so that a default-constructed Edge (the "not found" result of
// Edge::findEdge) never carries indeterminate ID/parentID values.
struct Key {
    int nodeID = -1;
    int asciiChar = -1;
};

// One edge of the suffix tree. The label is Input[startLabelIndex ..
// endLabelIndex] (inclusive); the edge runs from node `startNode` to node
// `endNode`. `ID` / `parentID` / `children` record the edge's own lookup key,
// its parent edge's key and the keys of its child edges.
//
// A default-constructed Edge is the "not found" value returned by findEdge
// (startNode == -1). Every field is now set to the -1 sentinel in that case:
// previously endNode, the label indices and both keys were left
// uninitialized even though callers read them on a not-found edge.
class Edge {
public:
    int startNode;
    int endNode;
    int startLabelIndex;
    int endLabelIndex;

    vector<Key> children;
    Key ID;
    Key parentID;

    void insert();
    void remove();
    static long returnHashKey(int node, int c);
    static Edge findEdge(int node, int c);
    void setID(int node, int c);
    void setParent(int node, int c);

    // "Not found" edge: everything is the -1 sentinel, no children.
    Edge() :
        startNode(-1),
        endNode(-1),
        startLabelIndex(-1),
        endLabelIndex(-1)
    {
        clearKeys();
    }

    // New edge leaving node `start` labelled Input[first..last]; allocates a
    // fresh node id for its end node from Node::noOfNodes.
    Edge(int start, int first, int last) :
        startNode(start),
        endNode(Node::noOfNodes++),
        startLabelIndex(first),
        endLabelIndex(last)
    {
        clearKeys();
    }

    // No user-declared destructor: the class owns nothing beyond its members,
    // and leaving it implicit keeps the implicit move operations available.

private:
    // Marks both lookup keys as unset.
    void clearKeys() {
        ID.asciiChar = -1;
        ID.nodeID = -1;
        parentID.nodeID = -1;
        parentID.asciiChar = -1;
    }
};

// suffixtree.h

// Result bundle of getStartEdge: the edge reached while matching a pattern,
// copies of its end-point node ids, and the accumulated label length.
struct Edge_info {
    int start_node;   // copy of e.startNode
    int end_node;     // copy of e.endNode
    int cumLength;    // summed label lengths of the edges visited
    Edge e;           // the edge itself
};

// Cursor into the tree used while building it: a node id plus the index
// range [startIndex, endIndex] into Input.
class suffixTree {
public:
    int rootNode;
    int startIndex;
    int endIndex;

    // Default cursor: node 0 with the range (-1, -1).
    suffixTree() : suffixTree(0, -1, -1) {}

    suffixTree(int root, int start, int end)
        : rootNode(root), startIndex(start), endIndex(end) {}

    // True when the index range is empty (startIndex past endIndex).
    bool endReal() { return endIndex < startIndex; }

    // True when the index range holds at least one character.
    bool endImg() { return !(endIndex < startIndex); }

    void migrateToClosestParent();
};

// edge.cpp
// Packs (nodeID, c) into the key used by edgeHash.
//
// Layout: the low 8 bits of `c` occupy the top 8 bits of the key and the node
// id occupies the remaining low bits. The previous `c << 59` left only 5 bits
// for the character on a 64-bit long, so characters that are equal modulo 32
// (e.g. '$' and 'd', or '0' and 'p') produced the same key for the same node,
// and the shift was undefined where long is 32 bits wide.
//
// The arithmetic is done on unsigned long so negative inputs (the -1
// sentinels, or a negative char value) wrap in a well-defined way; -1 as a
// node id maps to an all-ones pattern in the low bits.
long Edge::returnHashKey(int nodeID, int c) {
    const int charBits = 8;
    const int nodeBits = std::numeric_limits<unsigned long>::digits - charBits;
    const unsigned long nodeMask = (1UL << nodeBits) - 1UL;

    const unsigned long charPart = (static_cast<unsigned long>(c) & 0xFFUL) << nodeBits;
    const unsigned long nodePart = static_cast<unsigned long>(nodeID) & nodeMask;
    return static_cast<long>(charPart | nodePart);
}

void Edge::setID(int node, int c) {
    this->ID.asciiChar = c;
    this->ID.nodeID = node;
}
void Edge::setParent(int node, int c) {
    this->parentID.asciiChar = c;
    this->parentID.nodeID = node;
}

// suffixtree.cpp
// Text the tree is built over; the application entry points assign it
// (input plus a terminator character) before running the build phases.
string Input;
// The entry points set this to Input.length() - 1; carryPhase uses it as the
// end label index of every newly created leaf edge.
int inputLength;
// Next node id handed out by the three-argument Edge constructor.
int Node::noOfNodes = 1;
// Per-node data indexed by node id; allocated by the application entry points.
Node *nodeArray;
// Every edge of the tree, keyed by Edge::returnHashKey(startNode, first label char).
unordered_map<long, Edge> edgeHash;

void printAllEdges() {
    int count = 0;
    cout << "StartNode\tEndNode\tParentStart\tParentEnd" << endl;
    for (auto it = edgeHash.begin(); it != edgeHash.end(); it++) {
        Edge parent = Edge::findEdge(it->second.parentID.nodeID, it->second.parentID.asciiChar);
        cout << it->second.startNode << "\t\t" << it->second.endNode
            //             << "\t\t" << nodeArray[it->second.endNode].suffixNode
            //             << "\t\t" << it->second.startLabelIndex
            //             << "\t\t" << it->second.endLabelIndex
            << "\t\t" << parent.startNode
            << "\t\t" << parent.endNode << " ";
        count++;
        int head;
        if (inputLength > it->second.endLabelIndex)
            head = it->second.endLabelIndex;
        else
            head = inputLength;
        for (int i = it->second.startLabelIndex; i < head + 1; i++)
            cout << Input[i];
        cout << endl;
    }
    cout << "Total edges: " << count << endl;
    for (auto it = edgeHash.begin(); it != edgeHash.end(); it++) {
        Edge parent = Edge::findEdge(it->second.parentID.nodeID, it->second.parentID.asciiChar);
        cout << "Parent is: " << parent.startNode << " " << parent.endNode << endl;
        for (Key childID : parent.children) {
            Edge child = Edge::findEdge(childID.nodeID, childID.asciiChar);
            cout << child.startNode << "\t\t" << child.endNode << " ";
        }
        cout << endl;
    }

}

// Stores a copy of this edge in edgeHash under (startNode, first label
// character), replacing any edge already stored under that key.
void Edge::insert() {
    const char firstLabelChar = Input[startLabelIndex];
    edgeHash[returnHashKey(startNode, firstLabelChar)] = *this;
}

// Erases the edgeHash entry stored under this edge's current
// (startNode, first label character) key, if there is one.
void Edge::remove() {
    const char firstLabelChar = Input[startLabelIndex];
    edgeHash.erase(returnHashKey(startNode, firstLabelChar));
}

// Looks up the edge leaving `node` whose label starts with `asciiChar`.
// Returns a copy of the stored edge, or a default-constructed Edge
// (startNode == -1) when no such edge exists.
Edge Edge::findEdge(int node, int asciiChar) {
    const auto hit = edgeHash.find(returnHashKey(node, asciiChar));
    if (hit == edgeHash.end())
        return Edge();
    return hit->second;
}

void suffixTree::migrateToClosestParent() {

    if (endReal()) {
    }
    else {
        Edge e = Edge::findEdge(rootNode, Input[startIndex]);

        if (e.startNode == -1) {
            cout << rootNode << " " << startIndex << " " << Input[startIndex] << endl;
        }
        assert(e.startNode != -1);
        int labelLength = e.endLabelIndex - e.startLabelIndex;

        while (labelLength <= (endIndex - startIndex)) {
            startIndex += labelLength + 1;
            rootNode = e.endNode;
            if (startIndex <= endIndex) {
                e = Edge::findEdge(e.endNode, Input[startIndex]);
                if (e.startNode == -1) {
                    cout << rootNode << " " << startIndex << " " << Input[startIndex] << endl;
                }
                assert(e.startNode != -1);
                labelLength = e.endLabelIndex - e.startLabelIndex;
            }
        }

    }
}

int breakEdge(const suffixTree &s, Edge &e) {

    e.remove();

    Edge *newEdge = new Edge(s.rootNode, e.startLabelIndex,
        e.startLabelIndex + s.endIndex - s.startIndex);

    nodeArray[newEdge->endNode].suffixNode = s.rootNode;
    e.startLabelIndex += s.endIndex - s.startIndex + 1;
    e.startNode = newEdge->endNode;
    e.setID(e.startNode, Input[e.startLabelIndex]);

    newEdge->children.push_back(e.ID);
    newEdge->setID(newEdge->startNode, Input[newEdge->startLabelIndex]);
    if (newEdge->startNode > 0) {
        newEdge->setParent(nodeArray[newEdge->startNode].suffixNode, Input[s.startIndex - 1]);
    }
    newEdge->insert();


    for (Key id : e.children) {
        Edge child = Edge::findEdge(id.nodeID, id.asciiChar);
        assert(child.ID.nodeID != -1);
        child.setParent(e.ID.nodeID, e.ID.asciiChar);
        child.insert();
    }

    e.setParent(newEdge->ID.nodeID, newEdge->ID.asciiChar);
    e.insert();

    return newEdge->endNode;
}

// Runs one build phase: extends the tree with the character Input[lastIndex].
//
// Starting at the cursor `tree`, it repeatedly adds a leaf edge labelled
// Input[lastIndex..inputLength] (splitting an existing edge with breakEdge
// when the cursor ends in the middle of one), links the previously used
// parent node to the current one via suffixNode, and moves the cursor on
// (startIndex++ at node 0, otherwise through suffixNode). The loop stops as
// soon as the next character is already present at the cursor. Finally
// tree.endIndex is advanced and the cursor is re-canonicalised.
//
// The leaf edge is a local object: insert() stores a copy in edgeHash, so the
// per-extension heap allocation used previously was never freed.
void carryPhase(suffixTree &tree, int lastIndex) {

    int parentNode = tree.rootNode;
    int previousParentNode = -1;
    while (true) {
        Edge e;
        parentNode = tree.rootNode;

        if (tree.endReal()) {
            // Cursor sits on a node: done if an edge for the new character exists.
            e = Edge::findEdge(tree.rootNode, Input[lastIndex]);
            if (e.startNode != -1)
                break;
        }
        else {
            // Cursor sits inside an edge: done if the next label character matches.
            e = Edge::findEdge(tree.rootNode, Input[tree.startIndex]);
            int diff = tree.endIndex - tree.startIndex;
            if (Input[e.startLabelIndex + diff + 1] == Input[lastIndex])
                break;
            parentNode = breakEdge(tree, e);
        }

        if (previousParentNode > 0) {
            nodeArray[previousParentNode].suffixNode = parentNode;
        }

        Edge newEdge(parentNode, lastIndex, inputLength);
        newEdge.setID(newEdge.startNode, Input[newEdge.startLabelIndex]);


        if (e.startNode == -1 && newEdge.startNode > 0) {
            Edge parent = Edge::findEdge(nodeArray[newEdge.startNode].suffixNode, Input[tree.startIndex - 1]);
            if (parent.startNode != -1) {
                newEdge.setParent(parent.ID.nodeID, parent.ID.asciiChar);
                parent.children.push_back(newEdge.ID);
                parent.insert();
            }
        }

        if (e.startNode != -1 && newEdge.startNode > 0) {
            Edge parent = Edge::findEdge(e.parentID.nodeID, e.parentID.asciiChar);
            newEdge.setParent(parent.ID.nodeID, parent.ID.asciiChar);
            parent.children.push_back(newEdge.ID);
            parent.insert();
        }

        newEdge.insert();
        previousParentNode = parentNode;
        if (tree.rootNode == 0)
            tree.startIndex++;
        else {
            tree.rootNode = nodeArray[tree.rootNode].suffixNode;
        }

        tree.migrateToClosestParent();
    }

    if (previousParentNode > 0)
        nodeArray[previousParentNode].suffixNode = parentNode;
    tree.endIndex++;
    tree.migrateToClosestParent();
}


// Returns true when `pattern` can be matched from node 0 along edge labels of
// the tree currently stored in edgeHash, false otherwise (including for an
// empty pattern, for which no edge is found).
//
// `i` is the index of the last pattern character matched on earlier edges
// (-1 before the first edge); `iter` counts characters matched on the current
// edge. The sanity assert now compares (==) instead of assigning (=).
bool search(string pattern) {
    const int len = pattern.length();
    Edge e = Edge::findEdge(0, pattern[0]);
    if (e.startNode == -1)
        return false;

    int i = -1;
    while (i < len) {
        int iter = 0;
        while (e.endLabelIndex >= e.startLabelIndex + iter) {
            if (Input[e.startLabelIndex + iter] != pattern[i + iter + 1])
                return false;
            iter++;
            if (i + iter + 1 >= len)
                return true;
        }
        // The whole label of the current edge was consumed.
        assert(iter == (e.endLabelIndex - e.startLabelIndex + 1));
        e = Edge::findEdge(e.endNode, pattern[i + iter + 1]);
        if (e.startNode == -1)
            return false;
        i += iter;
    }
    return false;
}

// Matches `pattern` from node 0 along edge labels and returns the Input index
// of the label character that matched the LAST pattern character, or -1 when
// the pattern cannot be matched (including an empty pattern).
//
// `i` is the index of the last pattern character matched on earlier edges
// (-1 before the first edge); `iter` counts characters matched on the current
// edge. The sanity assert now compares (==) instead of assigning (=).
int getIndex(string pattern) {
    const int len = pattern.length();
    Edge e = Edge::findEdge(0, pattern[0]);
    if (e.startNode == -1)
        return -1;

    int i = -1;
    while (i < len) {
        int iter = 0;
        while (e.endLabelIndex >= e.startLabelIndex + iter) {
            if (Input[e.startLabelIndex + iter] != pattern[i + iter + 1])
                return -1;
            iter++;
            if (i + iter + 1 >= len)
                return e.startLabelIndex + iter - 1;
        }
        // The whole label of the current edge was consumed.
        assert(iter == (e.endLabelIndex - e.startLabelIndex + 1));
        e = Edge::findEdge(e.endNode, pattern[i + iter + 1]);
        if (e.startNode == -1)
            return -1;
        i += iter;
    }
    return -1;
}

// reset the data structure for next applications
// Resets the global tree state so another text can be indexed.
//
// - nodeArray is allocated with malloc by the entry points, so it is released
//   with free (not delete) and the pointer is nulled afterwards.
// - Node::noOfNodes goes back to 1, the value it is defined with, rather
//   than 0.
// - edgeHash and Input are cleared unconditionally, even when nodeArray was
//   never allocated.
void reset() {
    if (nodeArray) {
        std::free(nodeArray);
        nodeArray = nullptr;
    }
    edgeHash.clear();
    Node::noOfNodes = 1;
    Input.clear();
}

// identifies the first starting edge to the pattern of the tree
// Bundles an edge and an accumulated label length into an Edge_info.
static Edge_info makeEdgeInfo(const Edge &e, int cumLength) {
    Edge_info info;
    info.start_node = e.startNode;
    info.end_node = e.endNode;
    info.cumLength = cumLength;
    info.e = e;
    return info;
}

// Matches `pattern` from node 0 and returns the edge on which its last
// character is matched, together with the summed label lengths of all edges
// visited up to and including that edge.
//
// When the pattern cannot be matched the returned Edge_info has
// start_node == -1 and an edge with no children.
//
// Fixes over the previous version:
//  - a label/pattern mismatch now ends the search; the inner loop used to
//    spin forever because it had no mismatch branch;
//  - label indices are no longer read from a "not found" edge;
//  - the sanity assert compares (==) instead of assigning (=).
Edge_info getStartEdge(const string &pattern) {
    const int len = pattern.length();
    Edge e = Edge::findEdge(0, pattern[0]);
    if (e.startNode == -1)
        return makeEdgeInfo(e, 0);

    int cumLength = abs(e.startLabelIndex - e.endLabelIndex) + 1;
    int i = -1;   // index of the last pattern character matched on earlier edges
    while (i < len) {
        int iter = 0;   // characters matched on the current edge
        while (e.endLabelIndex >= e.startLabelIndex + iter) {
            if (Input[e.startLabelIndex + iter] != pattern[i + iter + 1]) {
                // Mismatch: report "not found" using a fully defined edge.
                e.startNode = -1;
                e.children.clear();
                return makeEdgeInfo(e, cumLength);
            }
            iter++;
            if (i + iter + 1 >= len)
                return makeEdgeInfo(e, cumLength);
        }
        // The whole label of the current edge was consumed.
        assert(iter == (e.endLabelIndex - e.startLabelIndex + 1));
        e = Edge::findEdge(e.endNode, pattern[i + iter + 1]);
        if (e.startNode == -1)
            return makeEdgeInfo(e, cumLength);
        cumLength += abs(e.startLabelIndex - e.endLabelIndex) + 1;
        i += iter;
    }
    return makeEdgeInfo(e, cumLength);
}

Sanity Testing for two applications: (1) Exact Matching (2) Smallest Lexicographical Rotation

// exactMatching.cpp
// Depth-first walk below `curEdge`. `cumLength` is the label length
// accumulated down to and including `curEdge`. For every child edge without
// children of its own, textLength minus the accumulated length down to that
// child is appended to `occurrence`.
void DFS_subtree(Edge curEdge, int cumLength, std::vector<int> &occurrence, int textLength) {
    for (const Key &childKey : curEdge.children) {
        const Edge child = Edge::findEdge(childKey.nodeID, childKey.asciiChar);
        const int depth =
            cumLength + abs(child.startLabelIndex - child.endLabelIndex) + 1;

        if (child.children.empty()) {
            // Childless edge: nothing to descend into, record the position.
            occurrence.push_back(((textLength - 1) - depth) + 1);
        } else {
            DFS_subtree(child, depth, occurrence, textLength);
        }
    }
}

vector<int> exactStringMatching(const string text, const string pattern) {
    // reset data structures
    reset();

    Input = text + '$';
    inputLength = Input.length() - 1;
    nodeArray = reinterpret_cast<Node *> (malloc(2 * inputLength * sizeof(Node)));

    suffixTree tree(0, 0, -1);
    for (int i = 0; i <= inputLength; i++) {
        carryPhase(tree, i);
    }

    std::vector<int> occurrences;

    // found - iterate to first occurrence
    Edge_info start = getStartEdge(pattern);
    Edge parent = Edge::findEdge(start.e.parentID.nodeID, start.e.parentID.asciiChar);
    // edgecase: occurance is immediately at leaf branch
    if (start.e.children.empty()) {
        if (start.start_node == 0) {
            occurrences.push_back(start.e.startLabelIndex);
        }
        else if (start.e.children.empty() && start.start_node > 0) {
            int found = getIndex(pattern) - pattern.length() + 1;
            if (found != -1)
                occurrences.push_back(found);
        }
    }
    // check if otherwise exists: then DFS through subtree
    else if (start.start_node != -1) {
        DFS_subtree(start.e, start.cumLength, occurrences, Input.length());
    }

    return occurrences;
}


// lexicoSmallest.cpp

// Starting from `minChild` with `cumStart` label characters already covered,
// repeatedly follows the child edge whose first character is smallest
// (skipping '$') until at least N characters are covered, then converts the
// final label end index into a (first, last) index pair into Input.
//
// `minNodeId` is now initialised and the walk stops when the current edge has
// no eligible child; previously an uninitialised node id was passed to
// findEdge in that situation.
pair<int, int> linearTraversalLex(Edge minChild, const size_t N, int cumStart) {
    std::pair<int, int> minLexIndex;

    int endIndex = cumStart - 1;
    while (static_cast<size_t>(cumStart) < N) {
        // find next min edge
        char minChar = std::numeric_limits<char>::max();
        int minNodeId = -1;
        for (const Key &childID : minChild.children) {
            if (childID.asciiChar < minChar && childID.asciiChar != '$') {
                minChar = childID.asciiChar;
                minNodeId = childID.nodeID;
            }
        }
        if (minNodeId == -1)
            break;   // no child to descend into

        minChild = Edge::findEdge(minNodeId, minChar);
        cumStart += abs(minChild.startLabelIndex - minChild.endLabelIndex) + 1;
        endIndex = minChild.endLabelIndex;
    }

    // Trim the characters covered beyond N off the end of the last label.
    int overflow = abs(cumStart - int(N));
    minLexIndex.second = abs(endIndex - overflow);
    minLexIndex.first = minLexIndex.second - N + 1;
    return minLexIndex;
}

// Returns the lexicographically smallest rotation of `s`.
//
// Builds the suffix tree of s + s + <largest char value> in the global state,
// starts on the edge leaving node 0 with the smallest character of `s`, and
// lets linearTraversalLex pick the index range of the rotation.
//
// An empty `s` is returned unchanged up front: previously it led to
// Input.substr(2, 0) on a one-character Input, which throws std::out_of_range.
string smallestLexicographicRotation(const string &s) {
    if (s.empty())
        return string();

    // reset data structures
    reset();

    Input = s + s + std::numeric_limits<char>::max();
    inputLength = Input.length() - 1;
    nodeArray = reinterpret_cast<Node *>(malloc(2 * inputLength * (sizeof(Node))));

    suffixTree tree(0, 0, -1);
    for (int i = 0; i <= inputLength; i++) {
        carryPhase(tree, i);
    }

    const char min_char = *std::min_element(s.begin(), s.end());

    Edge root = Edge::findEdge(0, min_char);
    int cumStart = abs(root.startLabelIndex - root.endLabelIndex) + 1;

    std::pair<int, int> startEnd = linearTraversalLex(root, s.size(), cumStart);
    return Input.substr(startEnd.first, startEnd.second - startEnd.first + 1);
}

// sanity checking - exactMatching
#define NUM_TEST_CASES_EXT1 200
#define RANDOM_STRING_SIZE_EXT1 20

// Reference implementation: returns every start index (ascending) at which
// `pattern` occurs in `text`, by comparing at each position.
//
// Returns an empty vector for an empty pattern and for a pattern longer than
// the text. The latter is now checked explicitly; before, the unsigned
// subtraction text.size() - pLen wrapped around and the result relied on the
// conversion of that huge value to int. The duplicate, unused substr() copy
// per position is gone as well.
std::vector<int> bruteForceOccurances(const std::string text, const std::string pattern) {

    std::vector<int> occurances;

    const std::size_t pLen = pattern.length();
    if (pLen == 0 || pLen > text.size()) return occurances;

    const std::size_t lastStart = text.size() - pLen;
    for (std::size_t i = 0; i <= lastStart; i++) {
        if (text.compare(i, pLen, pattern) == 0) occurances.push_back(static_cast<int>(i));
    }

    return occurances;
}

// Randomised self-check for exactStringMatching: for NUM_TEST_CASES_EXT1
// random lowercase texts of RANDOM_STRING_SIZE_EXT1 characters, takes a random
// substring as the pattern and asserts that the (sorted) suffix-tree result
// equals the brute-force result. Prints the case number after each case.
void exactMatchingSanityCheck() {
    cout << "exactMatchingSanityCheck Starting" << endl;
    std::default_random_engine seed;
    std::uniform_int_distribution<int> alphabetGenerator(0, 25);
    inputLength = RANDOM_STRING_SIZE_EXT1;

    Input.clear();
    for (int caseNo = 0; caseNo < NUM_TEST_CASES_EXT1; ++caseNo) {
        // random text
        string test;
        for (int k = 0; k < RANDOM_STRING_SIZE_EXT1; ++k)
            test.push_back('a' + alphabetGenerator(seed));

        // random pattern: start offset first, then a length that fits
        std::uniform_int_distribution<int> startGenerator(0, RANDOM_STRING_SIZE_EXT1 - 1);
        const int start = startGenerator(seed);
        std::uniform_int_distribution<int> lenGenerator(0, RANDOM_STRING_SIZE_EXT1 - start - 1);
        const int end = lenGenerator(seed);

        // the pattern must lie inside the text
        assert((start + end) < test.length());
        const std::string P = test.substr(start, end);

        // compare both implementations
        std::vector<int> occurancesBF = bruteForceOccurances(test, P);
        std::vector<int> occurancesUK = exactStringMatching(test, P);
        sort(occurancesUK.begin(), occurancesUK.end());

        assert(occurancesUK.size() == occurancesBF.size());
        for (int k = 0; k < occurancesBF.size(); ++k) {
            assert(occurancesBF.at(k) == occurancesUK.at(k));
        }

        // reset data structures
        Input.clear();
        test.clear();
        occurancesBF.clear();
        occurancesUK.clear();
        cout << caseNo << " ";
    }
    cout << endl;
    cout << "exactMatchingSanityCheck Complete" << endl << endl;
}

// sanity checking - lexicoSmallest
#define NUM_TEST_CASES_EXT2 1000
#define RANDOM_STRING_SIZE_EXT2 20

// Reference implementation: returns the lexicographically smallest rotation
// of `text` by comparing every window of text.length() characters taken from
// text + text.
string brute_force_min_rotation(const string &text) {
    const size_t width = text.length();
    const std::string doubled = text + text;

    std::string best = doubled.substr(0, width);
    for (size_t offset = 1; offset <= width; ++offset) {
        std::string candidate = doubled.substr(offset, width);
        if (candidate < best)
            best.swap(candidate);
    }
    return best;
}

// Randomised self-check for smallestLexicographicRotation: for
// NUM_TEST_CASES_EXT2 random lowercase strings of RANDOM_STRING_SIZE_EXT2
// characters, asserts that the suffix-tree result equals the brute-force
// result. Prints the case number after each case.
void lexicoSmallestSanityTest() {
    cout << "lexicoSmallestSanityTest Starting" << endl;
    std::default_random_engine seed;
    std::uniform_int_distribution<int> alphabetGenerator(0, 25);
    inputLength = RANDOM_STRING_SIZE_EXT2;

    string testStr;
    int caseNo = 0;
    while (caseNo < NUM_TEST_CASES_EXT2) {
        // random text
        testStr.clear();
        for (int k = 0; k < RANDOM_STRING_SIZE_EXT2; ++k)
            testStr.push_back('a' + alphabetGenerator(seed));

        // both implementations must agree on the smallest rotation
        const std::string bf_result = brute_force_min_rotation(testStr);
        const std::string uk_result = smallestLexicographicRotation(testStr);
        assert(bf_result == uk_result);

        cout << caseNo << " ";
        ++caseNo;
    }
    cout << endl;
    cout << "lexicoSmallestSanityTest Complete" << endl << endl;
}


// Entry point of the sanity-test snippet: runs both randomised self-checks in
// order; each one asserts on the first disagreement it finds.
int main(int argc, char** argv)
{
    exactMatchingSanityCheck();
    lexicoSmallestSanityTest();
    return 0;
}

Generalized Suffix Tree: Return the occurrences found within a set of delimited strings.

vector<int> generalizedExactStringMatching(const vector<string> texts, const string pattern) {
    reset();

    const char unique_delim = '$';
    for (int i = 0; i < texts.size(); i++)
        Input += texts.at(i) + (unique_delim + to_string(i));

    inputLength = Input.length() - 1;
    nodeArray = reinterpret_cast<Node *> (malloc(2 * inputLength * sizeof(Node)));

    suffixTree tree(0, 0, -1);
    for (int i = 0; i <= inputLength; i++) {
        carryPhase(tree, i);
    }

    std::vector<int> occurrences;

    // found - iterate to first occurrence
    Edge_info start = getStartEdge(pattern);
    Edge parent = Edge::findEdge(start.e.parentID.nodeID, start.e.parentID.asciiChar);
    // edgecase: occurance is immediately at leaf branch
    if (start.e.children.empty()) {
        if (start.start_node == 0) {
            occurrences.push_back(start.e.startLabelIndex);
        }
        else if (start.e.children.empty() && start.start_node > 0) {
            int found = getIndex(pattern) - pattern.length() + 1;
            if (found != -1)
                occurrences.push_back(found);
        }
    }
    // check if otherwise exists: then DFS through subtree
    else if (start.start_node != -1) {
        DFS_subtree(start.e, start.cumLength, occurrences, Input.length());
    }

    return occurrences;
}

// Entry point of the generalised-suffix-tree snippet: searches "an" in the
// delimited concatenation of two texts and prints the reported positions.
int main(int argc, char** argv)
{
    // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
    // b a n a n a $ 0 b a t m a n $ 1
    const vector<string> texts = {"banana", "batman"};
    const string needle = "an";

    vector<int> occurances = generalizedExactStringMatching(texts, needle);
    print(occurances); // 1 3 12
    return 0;
}
https://github.com/atuljangra/Ukkonen-SuffixTree