본문 바로가기

Problem Solving/알고리즘

[알고리즘] 최소 스패닝 트리

 최소 스패닝 트리 (Minimum Spanning Tree, MST)는 모든 정점을 연결하면서 가장 작은 가중치를 가지는 트리를 만드는 알고리즘이다. 트리를 만드는 알고리즘이므로 이 그래프는 사이클이 없는 DAG 형태여야 한다. 이 문제를 해결하는 데에는 두 가지 알고리즘이 대표적으로 사용되는데, 바로 크루스칼 알고리즘과 프림 알고리즘이다. 일단은 크루스칼 알고리즘을 설명한 뒤에 프림 알고리즘을 구현할 때가 되면 그때 와서 추가해 놓겠다.


크루스칼 알고리즘

 크루스칼 알고리즘은 모든 간선 중 가장 가중치가 작은 간선부터 차례로 연결시키는데, 만약 새로 연결하는 간선 때문에 사이클이 생긴다면 그 간선을 포기한다. 원리도 굉장히 간단하고 구현도 굉장히 간단하다. 근데 Greedy한 알고리즘 답게 증명은 간단하지 않다. 정당성의 증명을 PS계의 교과서 종만북에서는 다음과 같이 하고 있다. 

 -Greedy하게 선택하면 절대로 손해를 보지 않는다.

 by 귀류법: 크루스칼 알고리즘이 선택하는 간선 중 MST에 포함되지 않는 간선이 있고, 이 중 처음으로 선택되는 간선을 $e(u,v)$라고 하자. MST는 이 간선을 포함하지 않으니 $u$와 $v$는 다른 간선으로 연결되어 있을 것이다. 그런데 크루스칼 알고리즘은 반드시 가중치가 낮은 것부터 골라잡으니 이 $u$와 $v$를 연결하는 다른 간선은 크루스칼 알고리즘에 선택된 간선보다 가중치가 크거나 적어도 같을 것이다. 따라서 그냥 이 간선을 없애 버리고 크루스칼 알고리즘이 찾아준 간선을 연결하면 손해는 보지 않으면서 여전히 스패닝 트리이므로 원 명제는 항상 정당하다.

 -항상 최적의 선택만을 내려도 전체 문제의 최적해를 얻을 수 있다.

 이건 꽤나 자명하다. 어떤 MST에 하나의 노드를 추가한다고 하면

 그 노드와 연결되는 모든 간선 중 최소치를 연결한 것도 MST 아닌가. 따라서 항상 최적의 부분해를 구하면 그것이 전체 문제의 최적해로 이어진다.


어쨌던 간에 증명을 했으니 구현을 해야 한다. 구현을 할 때 생각해야 하는 것은

1. 가장 가중치가 작은 간선부터 차례로 연결

2. 사이클이 생기는지 검사


1번은 그냥 단순히 가중치 기준으로 오름차순으로 정렬하면 되고, 문제는 2번이다. 사이클이 생기는지 검사한다는 것은 현재 만들어진 트리에 속하는 두 노드가 다시 이어지는지 검사하는 것이다. 따라서 만약 사이클이 생긴다면 연결하려고 하는 두 노드가 이미 트리에 속해있다는 것을 뜻한다. 결론은, 두 노드가 이미 한 집합 안에 속하는지만 체크해 주면 된다는 것이다! 우리는 이럴 때 쓸 수 있는 아주 좋은 자료구조를 이미 알고 있다. Disjoint Set, 흔히 Union Find Tree라고 불리우는 바로 그 자료구조 맞다. 

구현은 아래 소스코드를 참조하자.


1197 최소 스패닝 트리

문제: https://www.acmicpc.net/problem/1197


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#include <cstdio>
#include <vector>
#include <algorithm>
 
using namespace std;
 
typedef struct DisJointSet
{
    vector<int> parent;
    vector<int> ra;
    DisJointSet(int n)
    {
        parent.assign(n+1,0);
        ra.assign(n+1,1);
        for(int i=0;i<=n;i++)
            parent[i]=i;
    }
    int fi(int x)
    {
        if(parent[x]==x) return x;
        else return parent[x]=fi(parent[x]);
    }
    void un(int a, int b)
    {
        a=fi(a); b=fi(b);
        if(a==b) return;
        if(ra[a]<ra[b]) swap(a,b);
        parent[b]=a;
        if(ra[a]==ra[b]) ra[a]++;
    }
}DJS;
 
int main()
{
    int n,m,sum=0;
    vector<pair<int,pair<int,int> > > E;
    scanf("%d %d",&n,&m);
    DJS S(n);
    for(int i=0;i<m;i++)
    {
        int t1,t2,w;
        scanf("%d %d %d",&t1,&t2,&w);
        E.push_back(make_pair(w,make_pair(t1,t2)));
    }
    sort(E.begin(),E.end());
    for(int i=0;i<m;i++)
    {
        int weight=E[i].first;
        int left=E[i].second.first;
        int right=E[i].second.second;
        if(S.fi(left)==S.fi(right)) continue;
        sum+=weight;
        S.un(left,right);
    }
    printf("%d",sum);
    return 0;
}
 
cs