Saturday, 25 January 2025

Using 'Union find' pattern for problem solving questions

About the pattern

The pattern is used to group elements into sets based on a specified property. Each element belongs to exactly one set, so the sets are disjoint (non-overlapping). Each set is stored as a tree: every element has a parent, and the root of the tree is called the representative element of that set. The parent of the representative node is the node itself. If we pick any element of a set and keep following parent links, we will always reach the representative element of the set (the root of the tree representing the disjoint set).

The pattern is implemented using two methods:
  1. find(node): Finds the representative of the set containing the node.
  2. union(node1, node2): Merges the sets containing node1 and node2 into one.

Let's say we have an array representing numbers 0 to 5. Before we start applying this pattern each element has itself as the parent.

This means we have 6 different sets, each containing one element. Now we want to start grouping them. Typically we end up with more than one set, as we will see with an actual example shortly, but for this example assume we want to create a single set.
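A minimal sketch of this starting state in Python may help. The names `parent` and `find` are chosen here for illustration only:

```python
# Six elements, 0 to 5; each one starts as its own parent (six singleton sets).
parent = list(range(6))

def find(node):
    # Walk parent links until we reach a node that is its own parent.
    while parent[node] != node:
        node = parent[node]
    return node

print(parent)                       # [0, 1, 2, 3, 4, 5]
print([find(v) for v in range(6)])  # [0, 1, 2, 3, 4, 5]
```

Since nothing has been merged yet, every element is its own representative.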


  1. Do union(0, 1): This merges nodes 0 and 1 into one set. The parent of 0 becomes 1, so 1 is now the representative of the set {0, 1}.


  2. Similarly do
    1. union(1,2)
    2. union(2,3)
    3. union(4,5): Notice how we merged 4 and 5 instead of 3 and 4. This is just to show that elements can be merged in any order, depending on the property under consideration.
    4. union(3,5)
At the end, we have a single tree rooted at 5: 0 → 1 → 2 → 3 → 5, with 4 also pointing to 5.


As you may have noticed, when the sets are built this way the height of the tree can be O(N) in the worst case, and so can the time complexity of each find.
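The walkthrough above can be reproduced with a small sketch. It assumes the naive rule used in the steps, where union(a, b) always hangs the root of a under the root of b. The helper here returns the number of links followed, purely to make the tree height visible:

```python
parent = list(range(6))

def find(node):
    # Return (representative, number of parent links followed).
    steps = 0
    while parent[node] != node:
        node = parent[node]
        steps += 1
    return node, steps

def union(node1, node2):
    root1, _ = find(node1)
    root2, _ = find(node2)
    if root1 != root2:
        # Naive merge: always attach the first root under the second.
        parent[root1] = root2

for a, b in [(0, 1), (1, 2), (2, 3), (4, 5), (3, 5)]:
    union(a, b)

print(parent)   # [1, 2, 3, 5, 5, 5]
print(find(0))  # (5, 4)
```

Finding the representative of 0 takes four hops for a set of only six elements, which is the chain-like worst case.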

Optimization: We can improve on this by maintaining a rank for each element, which here denotes the number of nodes in the tree rooted at that element, counting itself. The rank of a representative node is therefore the size of the set it represents. When merging two trees, we compare the ranks of the two representative nodes to decide which one becomes the new representative. E.g., node 5 above ends up with a rank of 6 (5 nodes under it + 1 counting itself).

With the above optimization, you will now select as the new representative the node with the higher rank among the two representatives being merged. So it will be something like below:




As you may have guessed, the tree stays balanced with this approach: its height is at most log(N), so the time complexity of find is reduced to O(log N).

So if I have to find the representative of 4, I look up its parent, which is 5, and then keep following parents until I reach the root, which in this case is 1. Once the root is reached (the node that is its own parent), we return it as the representative node. Since we traverse at most the height of the tree, the time complexity is O(log N).
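Here is a sketch of the same five unions with the rank optimization, keeping rank as the tree size. One assumption is made to match the walkthrough: on a tie, the first root is attached under the second. The names `size` and `find_path` are illustrative only:

```python
parent = list(range(6))
size = [1] * 6  # size[r] = number of nodes in the tree rooted at r

def find_path(node):
    # Return the list of nodes visited from `node` up to its root.
    path = [node]
    while parent[node] != node:
        node = parent[node]
        path.append(node)
    return path

def union(node1, node2):
    root1, root2 = find_path(node1)[-1], find_path(node2)[-1]
    if root1 == root2:
        return
    # Attach the smaller tree under the larger one.
    # On a tie, root2 stays the root (an assumption made for this sketch).
    if size[root1] > size[root2]:
        root1, root2 = root2, root1
    parent[root1] = root2
    size[root2] += size[root1]

for a, b in [(0, 1), (1, 2), (2, 3), (4, 5), (3, 5)]:
    union(a, b)

print(parent)        # [1, 1, 1, 1, 5, 1]
print(find_path(4))  # [4, 5, 1]
```

The longest path is now two hops instead of four. Tie-breaking is arbitrary; breaking ties the other way gives a different root but the same height bound.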

Using 'Union find' pattern

Now let's use this pattern to solve an actual problem.

Problem statement:
For a given integer, n, and an array, edges, return the number of connected components in a graph containing n nodes. 
The array edges[i] = [x, y] indicates that there’s an edge between x and y in the graph.

Constraint:
  • 1<=n<=1000
  • 0<=edges.length<=500
  • edges[i].length == 2
  • 0<=x,y<n
  • x!=y
  • No repeated edges
Notice how the problem statement says that the elements (vertices of the graph) are unique, and there are no repeated edges. Let's see how we can implement this using the union-find method.

Solution:
class UnionFind:

    def __init__(self, n):
        # Each node starts as its own parent (a singleton set).
        self.parent = list(range(n))
        # rank[v] is the number of nodes in the tree rooted at v.
        self.rank = [1] * n
        self.connected_components = n

    def find(self, v):
        # Follow parent links until we reach the root (its own parent).
        if self.parent[v] != v:
            return self.find(self.parent[v])
        return v

    def union(self, x, y):
        p1, p2 = self.find(x), self.find(y)
        if p1 != p2:
            # Make p1 the root with the higher rank.
            if self.rank[p1] < self.rank[p2]:
                p1, p2 = p2, p1
            # Attach the smaller tree under the larger one and update its rank.
            self.parent[p2] = p1
            self.rank[p1] += self.rank[p2]
            self.connected_components -= 1


def count_components(n, edges):
    uf = UnionFind(n)
    for v1, v2 in edges:
        uf.union(v1, v2)
    return uf.connected_components


If we run the above for the following data sets it gives the correct results:
Example 1:
  • Input: 5 , [[0,1],[1,2],[3,4]] 
  • Output: 2
Example 2:
  • Input: 6 , [[0,1],[3,4],[4,5]]
  • Output: 3
Now let's try to understand the solution and the implementation of the pattern. The core logic of the union-find pattern is in the UnionFind class. In the constructor, we initialize three things:
  1. parent - list that tracks the parent of each element
  2. rank - list that tracks the rank of each element
  3. connected_components - number of connected components

At the start, each node has itself assigned as the parent and the rank of each node is 1. Similarly, since each node is separate at the start, the number of connected components is equal to the number of nodes.

As we iterate over the edges, we pass each one to the union method to merge its two endpoints into a single set. Remember that the pattern is used to split unique elements into disjoint sets (connected components in this case). Whenever a union actually merges two different sets, we reduce connected_components by 1, since two components have become one. If both endpoints already share a representative, nothing changes. Once every edge has been processed, the disjoint sets correspond exactly to the connected components of the graph, so we simply return connected_components, which we have been using to track the count.
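To see why the count only drops when two different sets are merged, consider an edge that closes a cycle. The sketch below is a compact, function-only variant written for illustration (the name `count_components_sketch` is hypothetical and not part of the solution above):

```python
def count_components_sketch(n, edges):
    parent = list(range(n))
    size = [1] * n
    components = n

    def find(v):
        while parent[v] != v:
            v = parent[v]
        return v

    for x, y in edges:
        root_x, root_y = find(x), find(y)
        if root_x == root_y:
            # Already connected: this edge closes a cycle, count unchanged.
            continue
        if size[root_x] < size[root_y]:
            root_x, root_y = root_y, root_x
        parent[root_y] = root_x
        size[root_x] += size[root_y]
        components -= 1
    return components

# Edge [0, 2] joins two nodes that are already connected through 1.
print(count_components_sketch(4, [[0, 1], [1, 2], [0, 2]]))  # 2
```

Three edges are processed but only two of them reduce the count, leaving the components {0, 1, 2} and {3}.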

Also, notice how we use rank to ensure that, while merging two disjoint sets, the root with the higher rank becomes the parent. This keeps the resulting tree balanced, so find() operations take O(log N).

