본문 바로가기

CS ﹒ Algorithm

UnionFind 알고리즘과 경로 압축, 그리고 관련 문제

 

UnionFind 알고리즘은 서로소 집합을 표현하기 위한 알고리즘이다.

일반적으로 MST(Minumum Spanning Tree)를 구할 때 사용한다.

정말 별 거 없다.

어떤 노드 6개가 있다고 가정해보자.

 

 

이 때, 1-2, 2-3, 4-5, 5-6이 연결되어 있다면 다음과 같은 그림이 될 것이다.

 

 

위 그림에서 1-2-3과 4-5-6은 서로소 집합이다.

이걸 코드로 표현하기 위한 게 UnionFind 알고리즘이고.

 

UnionFind 알고리즘의 핵심은 두 노드를 합치는 Union 연산과, 합친 노드간의 부모 노드를 찾아주는 Find 연산이다.

그리고 추가적으로 경로 압축까지만 알면 UnionFind 알고리즘은 끝.

세가지를 빠르게 마스터해보자.

 

우선 집합을 나타내기 위해 배열이 필요하다.

다시 처음으로 돌아가자.

 

 

배열 이름은 맘대로 하세요.

나는 parents라는 이름을 주로 사용하고 있다.

parents는 각 노드의 부모노드가 무엇인지를 나타낸다.

처음에는 어떤 노드도 union() 연산을 수행하지 않았으므로 각 노드 자신이 부모 노드가 된다.

 

이제 다시 위와 같이 1-2, 2-3이라는 집합이 주어졌다고 가정해보자.

 

 

첫 번째로 1번 노드와 2번 노드의 union() 연산을 수행한다.

union() 연산 시에는 큰 값을 가진 노드가 작은 값을 가진 노드를 부모로 가지게 만들어준다.

 

 

위에서는 각 노드의 부모노드가 자기 자신이였기 때문에 노드 2가 바로 1번을 바라보게 되었지만, 이번에는 경우가 다르다.

union() 연산 시에는 각 노드의 부모노드 간에 합치는 과정을 거치게 되므로, 3번 노드는 2번이 아닌 2번의 부모 노드인 1번 노드에 합류하게 된다.

 

지금까지의 과정을 코드로 표현하면 아래와 같다.

 

1
2
3
4
5
6
7
8
9
10
    private static void union(int nodeA, int nodeB, int[] parent) {
       int parentA = find(nodeA);
       int parentB = find(nodeB);
 
        if (parentA < parentB) {
            parent[parentB] = parentA;
        } else {
            parent[parentA] = parentB;
        }
    }
cs

 

쉽다 쉽다.

 

이제 find() 연산에 대해 이해해보자.

find() 연산은 루트 노드를 찾아내기 위한 연산이다.

굉장히 간단하므로 코드를 먼저 읽고 글로 설명하겠다.

 

1
2
3
4
5
   private static int find(int node, int[] parents) {
       if (node == parents[node]) return lan;
 
       return find(parents[node], parents);
    }
cs

 

find(3)을 실행했다고 가정해보자.

 

1. parents[3]은 1이므로 node 3과 일치하지 않는다. 즉, node 3에게는 부모 노드가 있다는 뜻이다.

2. parents[node]는 1이므로 find(1)을 실행한다.

3. 1번 노드는 자기 자신이 부모 노드이므로 node == parents[node]를 만족하고 1을 반환한다.

 

너무 쉽죠?

이제 UnionFind 알고리즘을 마스터한 것 같다.

근데 여기에는 함정이 있다.

아까와 다르게 역순으로 union() 연산을 실행해보자.

이번에는 1-2, 2-3, 3-4, 5-6이 모두 연결되어 있다고 가정할 것이다.

 

 

 

아까랑 형태가 달라졌다.

루트노드부터 순차적으로 union() 연산을 수행하게 된다면 모두 루트 노드를 바로 바라보기 때문에 find() 연산의 시간복잡도가 O(1)이 되지만, 반대로 모든 집합이 역순으로 주어질 경우 O(n)이라는 시간 복잡도를 가지게 된다.

 

ㅋㅋ 너무 대충 썼네..

 

이를 완화하여 시간복잡도를 최악의 경우에도 O(logN)으로 유지하기 위한 기법이 경로 압축이다.

방법은 굉장히 단순하다, find() 연산을 수행할 때마다 그 결과를 parents[node]의 값으로 수정해준다면 어떨까?

 

1
2
3
4
5
    private static int find(int node, int[] parents) {
        if (node == parents[node]) return node;
 
        return parents[node] = find(parents[node], parents);
    }
cs

 

이렇게 코드를 작성하면 find() 연산을 실행할 때마다 노드의 경로를 정렬해주는 효과를 가지게 되어 시간복잡도가 비약적으로 상승한다.

이게 경로 압축 기법이다.

 

union(),find(),경로 압축까지 이해했으면 UnionFind 알고리즘은 끝이다.

바로 문제를 풀어보자.

 

 

 

https://www.acmicpc.net/problem/1043

 

전형적인 유니온 파인드 알고리즘 문제다.

 

1. 사람 수 N만큼 parents 배열을 초기화한다.

2. 진실을 아는 사람들의 번호를 저장한다.

3. 각 줄에 주어지는 파티를 union() 연산을 통해 집합으로 묶는다.

4. 2에서 저장했던 진실을 아는 사람들의 번호에 대해 find() 연산을 실행해서 해당 파티에 진실을 아는 사람이 있는지 구한다.

5. 총 파티 수에서 find()로 구한 값만큼 제거하면 거짓말을 할 수 있는 파티의 수를 구할 수 있다.

 

정답 코드는 아래에.

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
 
public class Main {
    public static void main(String[] args) throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        String[] input = br.readLine().split(" ");
 
        int countOfPeoples = Integer.parseInt(input[0]);
        int countOfParties = Integer.parseInt(input[1]);
        int[] parents = initializeParents(countOfPeoples);
 
        input = br.readLine().split(" ");
        int countOfKnownPeoples = Integer.parseInt(input[0]);
        int[] cautions = new int[countOfKnownPeoples];
        for (int i=0; i<countOfKnownPeoples; i++) {
            cautions[i] = Integer.parseInt(input[i+1]);
        }
 
        List<int[]> parties = new ArrayList<>();
        getParties(br, countOfParties, parties);
 
        for (int[] party : parties) {
            int firstPeople = party[0];
 
            for (int i=1; i<party.length; i++) {
                union(firstPeople, party[i], parents);
            }
        }
 
        int answer = countOfParties;
        for (int[] party : parties) {
            for (int j=0; j<cautions.length; j++) {
                int partyA = find(party[0], parents);
                int partyB = find(cautions[j], parents);
 
                if (partyA == partyB) {
                    answer--;
                    break;
                }
            }
        }
 
        System.out.println(answer);
    }
 
    private static int[] initializeParents(int countOfParties) {
        int[] parents = new int[countOfParties +1];
        for (int i=0; i<parents.length; i++) {
            parents[i] = i;
        }
        return parents;
    }
 
    private static void getParties(BufferedReader br, int countOfParties, List<int[]> parties) throws IOException {
        String[] input;
        for (int i = 0; i< countOfParties; i++) {
            input = br.readLine().split(" ");
            int count = Integer.parseInt(input[0]);
            int[] party = new int[count];
            for (int j=0; j<count; j++) {
                party[j] = Integer.parseInt(input[j+1]);
            }
 
            parties.add(party);
        }
    }
 
    private static void union(int peopleA, int peopleB, int[] parents) {
        int partyA = find(peopleA, parents);
        int partyB = find(peopleB, parents);
 
        if (partyA < partyB) {
            parents[partyB] = partyA;
        } else {
            parents[partyA] = partyB;
        }
    }
 
    private static int find(int people, int[] parents) {
        if (people == parents[people]) return people;
 
        return parents[people] = find(parents[people], parents);
    }
}
cs

 

 

추가적으로 풀어볼만한 문제

1. 여행 가자(골드4, https://www.acmicpc.net/problem/1976)

-> 정답 코드 (https://github.com/yangddoddi/Self_Study/tree/main/%EB%B0%B1%EC%A4%80/Gold/1976.%E2%80%85%EC%97%AC%ED%96%89%E2%80%85%EA%B0%80%EC%9E%90)

2. 집합의 표현(골드4, https://www.acmicpc.net/problem/1717)

-> 정답 코드

(https://github.com/yangddoddi/Self_Study/tree/main/%EB%B0%B1%EC%A4%80/Gold/1717.%E2%80%85%EC%A7%91%ED%95%A9%EC%9D%98%E2%80%85%ED%91%9C%ED%98%84)