Problem Solving/문제풀이

[BOJ] 1693 트리 색칠하기

Ryute 2018. 10. 22. 00:02


문제 링크: https://www.acmicpc.net/problem/1693

  • 문제 요약

$n$개의 정점으로 구성된 트리가 있다. 이 트리를 $1$에서 $n$번까지의 색깔로 칠하는데 $k$번 색깔로 칠하는 데는 $k$만큼의 비용이 든다. 인접한 정점을 같은 색으로 칠하지 않으려고 할 때 최소 비용을 구하는 문제이다.

  • 풀이 과정

이 문제를 푸는 데 가장 핵심적인 아이디어는, 트리를 색칠하는 데 많아 봐야 $\lg n$개의 색깔만이 필요하다는 것이다. 맨 처음에는 직관적으로 이 성질을 생각했는데, 좀 더 생각해보고 증명을 해 볼 수 있었다. 증명은 아래에!

만약 트리를 색칠하는 데 필요한 색깔이 생각보다 적다는 것을 알게 되면, 다음과 같은 동적계획법을 생각해 볼 수 있다. 

$DP[a][c]=$ $a$번째 정점을 색깔 $c$로 칠했을 때 $a$번째 정점을 루트로 하는 서브트리를 칠하는 데 드는 최소 비용

$\lg n=q$라고 두자. 그러면 $i$번째 정점의 자식들이 $A[1] ... A[k]$로 총 $k$개 존재한다고 할 때, $$DP[a][c]= \sum_{i=1}^{k} \: \min_{j=1}^{q} \: \begin{cases}  DP[A[i]][j] \; (j \neq c) \\ \infty \; (j=c) \end{cases} $$ 이라고 할 수 있다. 이를 계산하면 각 함수 호출마다 색깔에 대해서 항상 $q$번 반복하고 정점은 모든 호출을 합쳐 $O(nq)$번만 방문하므로 amortized 비슷하게 해서 총 시간복잡도가 $O(nq^2)$임을 알 수 있다.

그러면 이제 트리를 색칠하는 데 많아 봐야 $\lg n$개의 색깔만이 필요하다는 것을 증명하는 것 만이 남았다. 만약 어떤 트리가 있는데 그 아래 노드들은 어떻게 모두 색칠을 했고 루트 노드만 색칠을 해야 한다고 하자. 루트 노드에 색칠할 수 있는 가장 싼 색깔이 p라고 하자. 그러면 루트 노드의 자식 노드들은 자명하게 $1$~$p-1$번 색깔들을 모두 하나씩 가지고 있을 것이다. 여기서 트리의 크기에 대해 관찰해보면 다음과 같이 수학적 귀납법을 사용할 수 있다. $p=1$일 경우 정점이 하나일 것이기 때문에 이 트리의 크기는 1이다. $p$일때 트리의 크기를 $T[p]$라고 하면, $ T[p+1] \geq {\sum_{i=1}^{p} T[i]}+1 $ 이고 따라서 $p$가 1씩 증가할 때마다 $T[p]$는 항상 2배를 넘게 증가한다는 것을 알 수 있다. 여기서 $T[p]$가 $n$이므로 $p$는 $q$와 경향성이 같다. 완전히 같지는 않겠지만, 대강 실제로 색칠하는 데 필요한 색깔은 $\lg n$개보다 작거나 같음을 알 수 있겠다. 

이 모든 것들을 종합하여 트리 DP로 문제를 해결할 수 있다.

  • 소스 코드
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#include <bits/stdc++.h>
 
using namespace std;
typedef vector<int> vi;
const int INF=987654321;
 
vector<vi> adj,dir;
int n;
 
int check[100005];
void dfs(int x)
{
    check[x]=1;
    for(int i=0;i<adj[x].size();i++)
    {
        int there=adj[x][i];
        if(check[there]) continue;
        dir[x].push_back(there);
        dfs(there);
    }
}
 
int cache[100005][105];
int dp(int idx, int c)
{
    int& ret=cache[idx][c];
    if(ret!=-1return ret;
    int s=0;
    for(int i=0;i<dir[idx].size();i++)
    {
        int there=dir[idx][i];
        int mi=INF;
        for(int j=1;j<=18;j++)
            if(c!=j) mi=min(mi,dp(there,j));
        s+=mi;
    }
    return ret=s+c;
}
 
int main()
{
    memset(cache,-1,sizeof(cache));
    scanf("%d",&n);
    adj.assign(n+1,vi());
    dir.assign(n+1,vi());
    for(int i=0;i<n-1;i++)
    {
        int t1,t2;
        scanf("%d %d",&t1,&t2);
        adj[t1].push_back(t2);
        adj[t2].push_back(t1);
    }
    dfs(1);
    int mi=INF;
    for(int i=1;i<=18;i++)
        mi=min(mi,dp(1,i));
    printf("%d",mi);
    return 0;
}
 
cs