
Least Common Ancestor (LCA)

Where is this useful?

The Least Common Ancestor (LCA) data structure is useful wherever you have a directed acyclic graph in which every vertex has out-degree \(\leq 1\). In more common terms, each vertex either has a uniquely determined ‘parent’ or is a root node with no parent. The most common (and almost always the only) example is a rooted tree.

On these graphs, the LCA structure gives us a fast way to move ‘up’ the graph (towards a vertex’s parents). In particular, after an \(O(N \log N)\) preprocessing step, we can use it to find the least common ancestor of two vertices in \(O(\log N)\) time, which is where the data structure gets its name.

Continuing the parent analogy, a vertex \(u\) is an ancestor of \(v\) if \(u\) is \(v\)’s parent, or \(v\)’s parent’s parent, and so on. As long as a chain of ‘parentage’ connects \(v\) to \(u\), \(u\) is an ancestor of \(v\). We also consider \(v\) to be its own ancestor.

The least common ancestor problem then asks, given two vertices \(x\) and \(y\), for a vertex \(z\) such that \(z\) is an ancestor of both \(x\) and \(y\), but there is no vertex \(z' \neq z\) such that \(z\) is an ancestor of \(z'\) and \(z'\) is an ancestor of both \(x\) and \(y\). (In the tree example, we want the deepest vertex, i.e. the one furthest from the root, whose subtree contains both \(x\) and \(y\).)

Note that the least common ancestor can be \(x\) or \(y\), if \(x\) is an ancestor of \(y\) or vice-versa.
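To make the definition concrete, here is a minimal brute-force sketch, separate from the data structure we build below. It assumes a hypothetical `parent` list in which the root has parent -1, and that `x` and `y` are in the same tree. It collects every ancestor of `x` (including `x` itself), then walks up from `y` until it reaches one of them. Each query costs time proportional to the depth of the tree, which is exactly what the LCA structure improves on.

```python
def naive_lca(parent, x, y):
    """Brute-force least common ancestor.

    parent[v] is the parent of v, or -1 if v is the root (an input format
    assumed for this sketch). x and y must be in the same tree.
    """
    # Every ancestor of x, counting x itself.
    ancestors_of_x = set()
    while x != -1:
        ancestors_of_x.add(x)
        x = parent[x]
    # Walking up from y visits its ancestors from deepest to shallowest,
    # so the first one shared with x is the least common ancestor.
    while y not in ancestors_of_x:
        y = parent[y]
    return y


# Tree:      0
#           / \
#          1   2
#         / \
#        3   4
parent = [-1, 0, 0, 1, 1]
print(naive_lca(parent, 3, 4))  # 1
print(naive_lca(parent, 3, 2))  # 0
print(naive_lca(parent, 1, 3))  # 1 (x is an ancestor of y)
```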

Implementing the Data Structure

Interface

Let’s start by defining an interface for this data structure, and then slowly implement our methods.

class LCA:
    """
    vertices are represented as numbers 0->n-1.
    """

    def __init__(self, n_vertices):
        self.n = n_vertices
        # adjacent[v] holds (neighbour, weight) pairs; edges are undirected.
        self.adjacent = [[] for _ in range(self.n)]

    def add_edge(self, u, v, weight=1):
        self.adjacent[u].append((v, weight))
        self.adjacent[v].append((u, weight))

    def build(self, root):
        # Once edges are added, build the tree/data structure.
        pass  # TODO

    def query(self, u, v, root=None):
        # What is the lowest common ancestor of u, v?
        # Extension: Make this query from any root vertex you want.
        pass  # TODO

    def dist(self, u, v):
        # Find the distance between two vertices - very simple if we have LCA.
        pass  # TODO

#include <utility>
#include <vector>
using namespace std;

template<typename T = int> struct LCA {
    // vertices are represented as numbers 0->n-1.
    // adjacent[v] holds (neighbour, weight) pairs; edges are undirected.
    int n; vector<vector<pair<int, T> > > adjacent;

    LCA(int n_vertices) : n(n_vertices), adjacent(n) { }

    void add_edge(int u, int v, T weight=1) {
        adjacent[u].emplace_back(v, weight);
        adjacent[v].emplace_back(u, weight);
    }

    void build(int root=0) {
        // Once edges are added, build the tree/data structure.
        // TODO
    }

    int query(int u, int v, int root=-1) {
        // What is the lowest common ancestor of u, v?
        // Extension: Make this query from any root vertex you want.
        return -1; // TODO: placeholder so the stub returns a value
    }

    T dist(int u, int v) {
        // Find the distance between two vertices - very simple if we have LCA.
        return T(); // TODO: placeholder so the stub returns a value
    }
};

Useful data

First off, let’s save some intermediate data that will make our lives a lot easier and pin down the structure of the rooted tree. We’ll introduce three arrays: parent, level and length.

  • parent stores the direct parent of each vertex in the rooted tree (-1 for the root).
  • level stores the level of the tree the vertex is at (the number of edges between it and the root).
  • length stores the distance from the vertex to the root, measured using edge weights.

We’ll populate these fields in the build method, since all edges should be added by then.

class LCA:
    """
    vertices are represented as numbers 0->n-1.
    """

    def __init__(self, n_vertices):
        self.n = n_vertices
        self.adjacent = [[] for _ in range(self.n)]

    def add_edge(self, u, v, weight=1):
        self.adjacent[u].append((v, weight))
        self.adjacent[v].append((u, weight))

    def dfs(self, source, c_parent, c_level, c_length):
        # Search from the source down the tree and set parent, level, length accordingly.
        # Recursive: very deep trees can exceed Python's default recursion limit (about 1000).
        self.parent[source] = c_parent
        self.level[source] = c_level
        self.length[source] = c_length
        for child, weight in self.adjacent[source]:
            if child != c_parent:
                self.dfs(child, source, c_level + 1, c_length + weight)

    def build(self, root):
        # Once edges are added, build the tree/data structure.
        self.parent = [None]*self.n
        self.level = [None]*self.n
        self.length = [None]*self.n
        # The root has no parent, which we mark with -1.
        self.dfs(root, -1, 0, 0)

    def query(self, u, v, root=None):
        # What is the lowest common ancestor of u, v?
        # Extension: Make this query from any root vertex you want.
        pass  # TODO

    def dist(self, u, v):
        # Find the distance between two vertices - very simple if we have LCA.
        pass  # TODO

#include <utility>
#include <vector>
using namespace std;
typedef vector<int> vi;

template<typename T = int> struct LCA {
    // vertices are represented as numbers 0->n-1.
    int n; vector<vector<pair<int, T> > > adjacent;
    vi parent, level;
    vector<T> length;

    LCA(int n_vertices) : n(n_vertices), adjacent(n), parent(n), level(n), length(n) { }

    void add_edge(int u, int v, T weight=1) {
        adjacent[u].emplace_back(v, weight);
        adjacent[v].emplace_back(u, weight);
    }

    void dfs(int source, int c_parent, int c_level, T c_length) {
        // Search from the source down the tree and set parent, level, length accordingly.
        parent[source] = c_parent;
        level[source] = c_level;
        length[source] = c_length;
        for (auto v: adjacent[source])
            if (v.first != c_parent)
                dfs(v.first, source, c_level+1, c_length+v.second);
    }

    void build(int root=0) {
        // Once edges are added, build the tree/data structure.
        // The root has no parent, which we mark with -1.
        dfs(root, -1, 0, 0);
    }

    int query(int u, int v, int root=-1) {
        // What is the lowest common ancestor of u, v?
        // Extension: Make this query from any root vertex you want.
        return -1; // TODO: placeholder so the stub returns a value
    }

    T dist(int u, int v) {
        // Find the distance between two vertices - very simple if we have LCA.
        return T(); // TODO: placeholder so the stub returns a value
    }
};
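The recursive dfs above is the simplest way to fill these arrays, but in Python a tree that is more than roughly 1000 levels deep will hit the default recursion limit. A minimal sketch of an iterative alternative follows, assuming the same adjacency format of (neighbour, weight) pairs; `build_arrays` is a hypothetical standalone helper, not a method of the class. It visits vertices in a different order from the recursive version but produces the same three arrays.

```python
def build_arrays(adjacent, root):
    """Iteratively compute parent, level and length for a rooted tree.

    adjacent[v] is a list of (neighbour, weight) pairs, the same format that
    LCA.add_edge produces. An explicit stack replaces recursion.
    """
    n = len(adjacent)
    parent = [None] * n
    level = [None] * n
    length = [None] * n
    parent[root], level[root], length[root] = -1, 0, 0
    stack = [root]
    while stack:
        source = stack.pop()
        for child, weight in adjacent[source]:
            if child != parent[source]:
                parent[child] = source
                level[child] = level[source] + 1
                length[child] = length[source] + weight
                stack.append(child)
    return parent, level, length


# Example tree, edge weights in brackets: 0-1 [2], 0-2 [5], 1-3 [1]
adjacent = [[] for _ in range(4)]
for u, v, w in [(0, 1, 2), (0, 2, 5), (1, 3, 1)]:
    adjacent[u].append((v, w))
    adjacent[v].append((u, w))
parent, level, length = build_arrays(adjacent, 0)
print(parent)  # [-1, 0, 0, 1]
print(level)   # [0, 1, 1, 2]
print(length)  # [0, 2, 5, 3]
```

Like the recursive version, this relies on the input being a tree: the `child != parent[source]` test is the only thing stopping it from walking back up an edge.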

With these arrays, we can already look up several useful properties of each vertex in the rooted tree. Now for the interesting part: let’s start creating data unique to the LCA structure.

Ancestor Array

LCA gets its fast queries by precomputing a special array called ancestor. ancestor is a two-dimensional array in which ancestor[v][k] stores the ancestor of vertex v that is \(2^k\) edges towards the root (or -1 if the root is fewer than \(2^k\) edges away). For example, ancestor[v][0] is parent[v] (the parent is just the ancestor 1 edge towards the root), and ancestor[v][1] is parent[parent[v]] where it exists (2 edges towards the root is the parent’s parent).

If you populated this array by walking up the tree \(2^k\) steps for every entry, building it would take \(O(n^2)\) time in the worst case (for example, when the tree is a long path). Luckily, we can use the fact that ancestor[v][k] = ancestor[ancestor[v][k-1]][k-1]: you can move \(2^k\) steps towards the root by first moving \(2^{k-1}\) steps, which we’ve already computed, and then another \(2^{k-1}\) steps from this new position. Each entry then takes \(O(1)\) time, which reduces the build to \(O(n \log n)\).

We do this so that we can find the ancestor \(m\) edges towards the root, for any \(m\), using \(O(\log m)\) jumps, while storing only \(O(\log n)\) entries per vertex. We’ll see how this gets done later.

class LCA:
    """
    vertices are represented as numbers 0->n-1.
    """

    # number such that 2^{MAX_LOG} > n. 20 works for n <= 10^6.
    MAX_LOG = 20

    def __init__(self, n_vertices):
        self.n = n_vertices
        self.adjacent = [[] for _ in range(self.n)]

    def add_edge(self, u, v, weight=1):
        self.adjacent[u].append((v, weight))
        self.adjacent[v].append((u, weight))

    def dfs(self, source, c_parent, c_level, c_length):
        # Search from the source down the tree and set parent, level, length accordingly.
        # Recursive: very deep trees can exceed Python's default recursion limit (about 1000).
        self.parent[source] = c_parent
        self.level[source] = c_level
        self.length[source] = c_length
        for child, weight in self.adjacent[source]:
            if child != c_parent:
                self.dfs(child, source, c_level + 1, c_length + weight)

    def build(self, root):
        # Once edges are added, build the tree/data structure.
        self.parent = [None]*self.n
        self.level = [None]*self.n
        self.length = [None]*self.n
        self.dfs(root, -1, 0, 0)
        # NEW: Compute ancestor (-1 means "no vertex that far up").
        self.ancestor = [[-1]*self.MAX_LOG for _ in range(self.n)]
        # Initial step: ancestor[v][0] = parent[v]
        for v in range(self.n):
            self.ancestor[v][0] = self.parent[v]
        # Now, compute ancestor[v][k] for k = 1 .. MAX_LOG-1
        for k in range(1, self.MAX_LOG):
            for v in range(self.n):
                # Skip -1: in Python, ancestor[-1] would silently read the last row.
                if self.ancestor[v][k-1] != -1:
                    # Move 2^{k-1} up, then 2^{k-1} again.
                    self.ancestor[v][k] = self.ancestor[self.ancestor[v][k-1]][k-1]

    def query(self, u, v, root=None):
        # What is the lowest common ancestor of u, v?
        # Extension: Make this query from any root vertex you want.
        pass  # TODO

    def dist(self, u, v):
        # Find the distance between two vertices - very simple if we have LCA.
        pass  # TODO

#include <utility>
#include <vector>
using namespace std;
typedef vector<int> vi;
typedef vector<vi> vvi;

template<typename T = int> struct LCA {
    // vertices are represented as numbers 0->n-1.
    // number such that 2^{MAX_LOG} > n. 20 works for n <= 10^6.
    int MAX_LOG = 20;
    int n; vector<vector<pair<int, T> > > adjacent;
    vi parent, level;
    vvi ancestor;
    vector<T> length;

    LCA(int n_vertices) : n(n_vertices), adjacent(n), parent(n), level(n), length(n) { }

    void add_edge(int u, int v, T weight=1) {
        adjacent[u].emplace_back(v, weight);
        adjacent[v].emplace_back(u, weight);
    }

    void dfs(int source, int c_parent, int c_level, T c_length) {
        // Search from the source down the tree and set parent, level, length accordingly.
        parent[source] = c_parent;
        level[source] = c_level;
        length[source] = c_length;
        for (auto v: adjacent[source])
            if (v.first != c_parent)
                dfs(v.first, source, c_level+1, c_length+v.second);
    }

    void build(int root=0) {
        // Once edges are added, build the tree/data structure.
        dfs(root, -1, 0, 0);
        // NEW: Compute ancestor (-1 means "no vertex that far up").
        ancestor.assign(n, vi(MAX_LOG, -1));
        // Initial step: ancestor[v][0] = parent[v]
        for (int v=0; v<n; v++)
            ancestor[v][0] = parent[v];
        // Now, compute ancestor[v][k] for k = 1 .. MAX_LOG-1
        for (int k=1; k < MAX_LOG; k++)
            for (int v=0; v<n; v++)
                if (ancestor[v][k-1] != -1) {
                    // Move 2^{k-1} up, then 2^{k-1} again.
                    ancestor[v][k] = ancestor[ancestor[v][k-1]][k-1];
                }
    }

    int query(int u, int v, int root=-1) {
        // What is the lowest common ancestor of u, v?
        // Extension: Make this query from any root vertex you want.
        return -1; // TODO: placeholder so the stub returns a value
    }

    T dist(int u, int v) {
        // Find the distance between two vertices - very simple if we have LCA.
        return T(); // TODO: placeholder so the stub returns a value
    }
};
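To see the doubling rule and the \(O(\log m)\) jump in isolation, here is a minimal standalone sketch. It assumes a plain `parent` list (root marked with -1) instead of the class, and `build_ancestor` and `jump` are hypothetical helper names. `jump` takes one table lookup for each set bit of m, reading the bits from lowest to highest; it assumes 0 <= m < 2**max_log.

```python
def build_ancestor(parent, max_log):
    """ancestor[v][k] = the vertex 2^k edges above v, or -1 if there is none."""
    n = len(parent)
    ancestor = [[-1] * max_log for _ in range(n)]
    for v in range(n):
        ancestor[v][0] = parent[v]
    for k in range(1, max_log):
        for v in range(n):
            mid = ancestor[v][k - 1]
            if mid != -1:
                # 2^k steps = 2^(k-1) steps to mid, then 2^(k-1) more from mid.
                ancestor[v][k] = ancestor[mid][k - 1]
    return ancestor


def jump(ancestor, v, m):
    """Return the ancestor m edges above v, or -1 if v is not that deep."""
    k = 0
    while m > 0 and v != -1:
        if m & 1:
            # Bit k of m is set, so take a 2^k step.
            v = ancestor[v][k]
        m >>= 1
        k += 1
    return v


# A path 0 - 1 - 2 - ... - 9 rooted at 0, so parent[v] = v - 1.
parent = [-1] + list(range(9))
ancestor = build_ancestor(parent, max_log=4)  # 2^4 = 16 > 10 vertices
print(ancestor[9])            # [8, 7, 5, 1]
print(jump(ancestor, 9, 6))   # 3, since 6 = 4 + 2
print(jump(ancestor, 9, 9))   # 0
print(jump(ancestor, 9, 10))  # -1
```

The query below uses the same idea but tries k from largest to smallest; either order works, because the jumps for different bits commute.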

Query

That’s most of the ingenuity out of the way; now we can implement query.

Provided we want the LCA with respect to the root we called build from, we can define the LCA l of u and v in the following way:

l is the common ancestor of u and v that maximises level[l].

We also know that level[l] <= min(level[u], level[v]). Using this, we can calculate query(u, v) by:

  • Find the ancestors of u and v (call them a1 and a2) such that level[a1] = level[a2] = min(level[u], level[v]).
  • Move a1 and a2 towards the root together (to higher and higher ancestors) until a1 == a2. That vertex is the LCA of u and v.

With the ancestor array we’ve generated, we can do both of these things in \(O(\log n)\) time. Let’s see how:

class LCA:
    """
    vertices are represented as numbers 0->n-1.
    """

    # number such that 2^{MAX_LOG} > n. 20 works for n <= 10^6.
    MAX_LOG = 20

    def __init__(self, n_vertices):
        self.n = n_vertices
        self.adjacent = [[] for _ in range(self.n)]

    def add_edge(self, u, v, weight=1):
        self.adjacent[u].append((v, weight))
        self.adjacent[v].append((u, weight))

    def dfs(self, source, c_parent, c_level, c_length):
        # Search from the source down the tree and set parent, level, length accordingly.
        # Recursive: very deep trees can exceed Python's default recursion limit (about 1000).
        self.parent[source] = c_parent
        self.level[source] = c_level
        self.length[source] = c_length
        for child, weight in self.adjacent[source]:
            if child != c_parent:
                self.dfs(child, source, c_level + 1, c_length + weight)

    def build(self, root):
        # Once edges are added, build the tree/data structure.
        self.parent = [None]*self.n
        self.level = [None]*self.n
        self.length = [None]*self.n
        self.dfs(root, -1, 0, 0)
        self.ancestor = [[-1]*self.MAX_LOG for _ in range(self.n)]
        # Initial step: ancestor[v][0] = parent[v]
        for v in range(self.n):
            self.ancestor[v][0] = self.parent[v]
        # Now, compute ancestor[v][k] for k = 1 .. MAX_LOG-1
        for k in range(1, self.MAX_LOG):
            for v in range(self.n):
                # Skip -1: in Python, ancestor[-1] would silently read the last row.
                if self.ancestor[v][k-1] != -1:
                    # Move 2^{k-1} up, then 2^{k-1} again.
                    self.ancestor[v][k] = self.ancestor[self.ancestor[v][k-1]][k-1]

    def query(self, u, v, root=None):
        # What is the lowest common ancestor of u, v?
        # Extension: Make this query from any root vertex you want.
        if root is not None:
            pass  # TODO
        # Make u the shallower vertex, to simplify the code below.
        if self.level[u] > self.level[v]:
            u, v = v, u
        # STEP 1: lift v until it is on the same level as u.
        for k in range(self.MAX_LOG-1, -1, -1):
            if self.level[v] - (1 << k) >= self.level[u]:
                # v is at least 2^k levels below u, so move it up 2^k levels.
                v = self.ancestor[v][k]
        # Now level[u] == level[v]: the loop made one 2^k jump per set bit
        # in the binary representation of the original level difference.
        # If u was an ancestor of v, we are already done.
        if u == v:
            return u
        # STEP 2: lift u and v together to the highest ancestors where they
        # still differ. Their common parent is then the LCA.
        for k in range(self.MAX_LOG-1, -1, -1):
            if self.ancestor[u][k] != self.ancestor[v][k]:
                # Move both up 2^k steps
                u = self.ancestor[u][k]
                v = self.ancestor[v][k]
        return self.parent[u]

    def dist(self, u, v):
        # Find the distance between two vertices - very simple if we have LCA.
        pass  # TODO

#include <utility>
#include <vector>
using namespace std;
typedef vector<int> vi;
typedef vector<vi> vvi;

template<typename T = int> struct LCA {
    // vertices are represented as numbers 0->n-1.
    // number such that 2^{MAX_LOG} > n. 20 works for n <= 10^6.
    int MAX_LOG = 20;
    int n; vector<vector<pair<int, T> > > adjacent;
    vi parent, level;
    vvi ancestor;
    vector<T> length;

    LCA(int n_vertices) : n(n_vertices), adjacent(n), parent(n), level(n), length(n) { }

    void add_edge(int u, int v, T weight=1) {
        adjacent[u].emplace_back(v, weight);
        adjacent[v].emplace_back(u, weight);
    }

    void dfs(int source, int c_parent, int c_level, T c_length) {
        // Search from the source down the tree and set parent, level, length accordingly.
        parent[source] = c_parent;
        level[source] = c_level;
        length[source] = c_length;
        for (auto v: adjacent[source])
            if (v.first != c_parent)
                dfs(v.first, source, c_level+1, c_length+v.second);
    }

    void build(int root=0) {
        // Once edges are added, build the tree/data structure.
        dfs(root, -1, 0, 0);
        // NEW: Compute ancestor (-1 means "no vertex that far up").
        ancestor.assign(n, vi(MAX_LOG, -1));
        // Initial step: ancestor[v][0] = parent[v]
        for (int v=0; v<n; v++)
            ancestor[v][0] = parent[v];
        // Now, compute ancestor[v][k] for k = 1 .. MAX_LOG-1
        for (int k=1; k < MAX_LOG; k++)
            for (int v=0; v<n; v++)
                if (ancestor[v][k-1] != -1) {
                    // Move 2^{k-1} up, then 2^{k-1} again.
                    ancestor[v][k] = ancestor[ancestor[v][k-1]][k-1];
                }
    }

    int query(int u, int v, int root=-1) {
        // What is the lowest common ancestor of u, v?
        // Extension: Make this query from any root vertex you want.
        if (root != -1) {
            // TODO
        }
        // Make u the shallower vertex, to simplify the code below.
        if (level[u] > level[v]) swap(u, v);
        // STEP 1: lift v until it is on the same level as u.
        for (int k=MAX_LOG-1; k>=0; k--)
            if (level[v] - (1 << k) >= level[u]) {
                // v is at least 2^k levels below u, so move it up 2^k levels.
                v = ancestor[v][k];
            }
        // Now level[u] == level[v]: the loop made one 2^k jump per set bit
        // in the binary representation of the original level difference.
        // If u was an ancestor of v, we are already done.
        if (u == v) return u;
        // STEP 2: lift u and v together to the highest ancestors where they
        // still differ. Their common parent is then the LCA.
        for (int k=MAX_LOG-1; k>=0; k--)
            if (ancestor[u][k] != ancestor[v][k]) {
                // Move both up 2^k steps
                u = ancestor[u][k];
                v = ancestor[v][k];
            }
        return parent[u];
    }

    T dist(int u, int v) {
        // Find the distance between two vertices - very simple if we have LCA.
        return T(); // TODO: placeholder so the stub returns a value
    }
};
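Why does STEP 1 always land exactly on level[u]? Each iteration subtracts \(2^k\) from the remaining level difference only if it fits, trying k from largest to smallest, which amounts to reading off the binary representation of that difference. The sketch below isolates just that loop; `lift_steps` is a hypothetical helper, and it works on the difference d = level[v] - level[u] rather than on vertices.

```python
def lift_steps(d, max_log=20):
    """Return the jump sizes the STEP 1 loop in query takes to climb d levels.

    Hypothetical helper: d plays the role of level[v] - level[u].
    """
    steps = []
    for k in range(max_log - 1, -1, -1):
        # Same test as level[v] - (1 << k) >= level[u] in query.
        if d >= 1 << k:
            steps.append(1 << k)
            d -= 1 << k
    return steps


print(lift_steps(13))  # [8, 4, 1], since 13 = 0b1101
print(lift_steps(0))   # [] (u and v already on the same level)
```

Every d below \(2^{\text{MAX\_LOG}}\) is a sum of distinct powers of two below that bound, so the loop always consumes d completely; this is why MAX_LOG must satisfy \(2^{\text{MAX\_LOG}} > n\).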

Nice! That’s the main functionality of LCA completed.

Corollaries

Let’s quickly tackle the two remaining implementations:

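  • Calculating the distance between two vertices u and v is the same as calculating the distance between u and query(u, v) and adding the distance between v and query(u, v). Since length stores each vertex’s distance to the root, this is length[u] + length[v] - 2 * length[query(u, v)].
  • Calculating the LCA with respect to a particular root just requires a slight change in perspective. For two vertices u and v and a custom root r, the LCA will always be one of query(u, v), query(u, r) or query(v, r). In fact, at least two of these three always coincide, and the answer is the odd one out (or the common value, if all three are equal).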
class LCA:
    """
    vertices are represented as numbers 0->n-1.
    """

    # number such that 2^{MAX_LOG} > n. 20 works for n <= 10^6.
    MAX_LOG = 20

    def __init__(self, n_vertices):
        self.n = n_vertices
        self.adjacent = [[] for _ in range(self.n)]

    def add_edge(self, u, v, weight=1):
        self.adjacent[u].append((v, weight))
        self.adjacent[v].append((u, weight))

    def dfs(self, source, c_parent, c_level, c_length):
        # Search from the source down the tree and set parent, level, length accordingly.
        # Recursive: very deep trees can exceed Python's default recursion limit (about 1000).
        self.parent[source] = c_parent
        self.level[source] = c_level
        self.length[source] = c_length
        for child, weight in self.adjacent[source]:
            if child != c_parent:
                self.dfs(child, source, c_level + 1, c_length + weight)

    def build(self, root):
        # Once edges are added, build the tree/data structure.
        self.parent = [None]*self.n
        self.level = [None]*self.n
        self.length = [None]*self.n
        self.dfs(root, -1, 0, 0)
        self.ancestor = [[-1]*self.MAX_LOG for _ in range(self.n)]
        # Initial step: ancestor[v][0] = parent[v]
        for v in range(self.n):
            self.ancestor[v][0] = self.parent[v]
        # Now, compute ancestor[v][k] for k = 1 .. MAX_LOG-1
        for k in range(1, self.MAX_LOG):
            for v in range(self.n):
                # Skip -1: in Python, ancestor[-1] would silently read the last row.
                if self.ancestor[v][k-1] != -1:
                    # Move 2^{k-1} up, then 2^{k-1} again.
                    self.ancestor[v][k] = self.ancestor[self.ancestor[v][k-1]][k-1]

    def query(self, u, v, root=None):
        # What is the lowest common ancestor of u, v?
        # Extension: Make this query from any root vertex you want.
        if root is not None:
            # NEW: Custom root -- see diagrams below for reasoning.
            # Two of a, b, c always coincide; the LCA for `root` is the odd one out.
            a = self.query(u, v)
            b = self.query(u, root)
            c = self.query(v, root)
            # Case 1: root is in the same component as u when `a` is removed from the tree. So `b` is the LCA
            if a == c and c != b:
                return b
            # Case 2: root is in the same component as v when `a` is removed from the tree. So `c` is the LCA
            if a == b and c != b:
                return c
            # Case 3: b == c, so root is in neither of those components. So return a
            return a
        # Make u the shallower vertex, to simplify the code below.
        if self.level[u] > self.level[v]:
            u, v = v, u
        # STEP 1: lift v until it is on the same level as u.
        for k in range(self.MAX_LOG-1, -1, -1):
            if self.level[v] - (1 << k) >= self.level[u]:
                # v is at least 2^k levels below u, so move it up 2^k levels.
                v = self.ancestor[v][k]
        # Now level[u] == level[v]: the loop made one 2^k jump per set bit
        # in the binary representation of the original level difference.
        # If u was an ancestor of v, we are already done.
        if u == v:
            return u
        # STEP 2: lift u and v together to the highest ancestors where they
        # still differ. Their common parent is then the LCA.
        for k in range(self.MAX_LOG-1, -1, -1):
            if self.ancestor[u][k] != self.ancestor[v][k]:
                # Move both up 2^k steps
                u = self.ancestor[u][k]
                v = self.ancestor[v][k]
        return self.parent[u]

    def dist(self, u, v):
        # NEW: Find the distance between two vertices
        return self.length[u] + self.length[v] - 2 * self.length[self.query(u, v)]

#include <utility>
#include <vector>
using namespace std;
typedef vector<int> vi;
typedef vector<vi> vvi;

template<typename T = int> struct LCA {
    // vertices are represented as numbers 0->n-1.
    // number such that 2^{MAX_LOG} > n. 20 works for n <= 10^6.
    int MAX_LOG = 20;
    int n; vector<vector<pair<int, T> > > adjacent;
    vi parent, level;
    vvi ancestor;
    vector<T> length;

    LCA(int n_vertices) : n(n_vertices), adjacent(n), parent(n), level(n), length(n) { }

    void add_edge(int u, int v, T weight=1) {
        adjacent[u].emplace_back(v, weight);
        adjacent[v].emplace_back(u, weight);
    }

    void dfs(int source, int c_parent, int c_level, T c_length) {
        // Search from the source down the tree and set parent, level, length accordingly.
        parent[source] = c_parent;
        level[source] = c_level;
        length[source] = c_length;
        for (auto v: adjacent[source])
            if (v.first != c_parent)
                dfs(v.first, source, c_level+1, c_length+v.second);
    }

    void build(int root=0) {
        // Once edges are added, build the tree/data structure.
        dfs(root, -1, 0, 0);
        // NEW: Compute ancestor (-1 means "no vertex that far up").
        ancestor.assign(n, vi(MAX_LOG, -1));
        // Initial step: ancestor[v][0] = parent[v]
        for (int v=0; v<n; v++)
            ancestor[v][0] = parent[v];
        // Now, compute ancestor[v][k] for k = 1 .. MAX_LOG-1
        for (int k=1; k < MAX_LOG; k++)
            for (int v=0; v<n; v++)
                if (ancestor[v][k-1] != -1) {
                    // Move 2^{k-1} up, then 2^{k-1} again.
                    ancestor[v][k] = ancestor[ancestor[v][k-1]][k-1];
                }
    }

    int query(int u, int v, int root=-1) {
        // What is the lowest common ancestor of u, v?
        // Extension: Make this query from any root vertex you want.
        if (root != -1) {
            // NEW: Custom root -- see diagrams below for reasoning.
            // Two of a, b, c always coincide; the LCA for `root` is the odd one out.
            int a = query(u, v);
            int b = query(u, root);
            int c = query(v, root);
            // Case 1: root is in the same component as u when `a` is removed from the tree. So `b` is the LCA
            if (a == c && c != b) return b;
            // Case 2: root is in the same component as v when `a` is removed from the tree. So `c` is the LCA
            if (a == b && c != b) return c;
            // Case 3: b == c, so root is in neither of those components. So return a
            return a;
        }
        // Make u the shallower vertex, to simplify the code below.
        if (level[u] > level[v]) swap(u, v);
        // STEP 1: lift v until it is on the same level as u.
        for (int k=MAX_LOG-1; k>=0; k--)
            if (level[v] - (1 << k) >= level[u]) {
                // v is at least 2^k levels below u, so move it up 2^k levels.
                v = ancestor[v][k];
            }
        // Now level[u] == level[v]: the loop made one 2^k jump per set bit
        // in the binary representation of the original level difference.
        // If u was an ancestor of v, we are already done.
        if (u == v) return u;
        // STEP 2: lift u and v together to the highest ancestors where they
        // still differ. Their common parent is then the LCA.
        for (int k=MAX_LOG-1; k>=0; k--)
            if (ancestor[u][k] != ancestor[v][k]) {
                // Move both up 2^k steps
                u = ancestor[u][k];
                v = ancestor[v][k];
            }
        return parent[u];
    }

    T dist(int u, int v) {
        // NEW: Find the distance between two vertices
        return length[u] + length[v] - 2 * length[query(u, v)];
    }
};
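As a sanity check of the two corollaries, the sketch below compares them against brute force on one small, arbitrarily chosen weighted tree; it is a check, not a proof. It deliberately avoids the LCA class: `root_at` and `naive_lca` are hypothetical helpers that re-root the tree with a BFS and walk parent pointers.

```python
from collections import deque
from itertools import product


def root_at(adjacent, root):
    """BFS from root; return parent (root -> -1) and weighted-depth lists."""
    n = len(adjacent)
    parent, length = [-1] * n, [0] * n
    seen = [False] * n
    seen[root] = True
    queue = deque([root])
    while queue:
        u = queue.popleft()
        for v, w in adjacent[u]:
            if not seen[v]:
                seen[v] = True
                parent[v], length[v] = u, length[u] + w
                queue.append(v)
    return parent, length


def naive_lca(parent, x, y):
    """Walk up from y until we reach an ancestor of x (x counts as its own)."""
    ancestors_of_x = set()
    while x != -1:
        ancestors_of_x.add(x)
        x = parent[x]
    while y not in ancestors_of_x:
        y = parent[y]
    return y


edges = [(0, 1, 3), (0, 2, 1), (1, 3, 4), (1, 4, 2), (2, 5, 7), (4, 6, 1)]
n = 7
adjacent = [[] for _ in range(n)]
for u, v, w in edges:
    adjacent[u].append((v, w))
    adjacent[v].append((u, w))

parent0, length0 = root_at(adjacent, 0)  # the tree as build(0) would see it

# Rerooting rule: the odd one out of a, b, c (same cases as in query).
for u, v, r in product(range(n), repeat=3):
    a = naive_lca(parent0, u, v)
    b = naive_lca(parent0, u, r)
    c = naive_lca(parent0, v, r)
    odd = b if (a == c and c != b) else c if (a == b and c != b) else a
    parent_r, _ = root_at(adjacent, r)
    assert odd == naive_lca(parent_r, u, v)

# Distance formula, checked against the weighted depth of v when rooted at u.
for u, v in product(range(n), repeat=2):
    _, length_u = root_at(adjacent, u)
    lca = naive_lca(parent0, u, v)
    assert length0[u] + length0[v] - 2 * length0[lca] == length_u[v]

print("both corollaries hold on this tree")
```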

And that’s our implementation done! Now get out there and solve some problems!

Related Problems

This post is licensed under GNU GPL V3 by the author.