[BOJ] 1693 트리 색칠하기
문제 링크: https://www.acmicpc.net/problem/1693
- 문제 요약
$n$개의 정점으로 구성된 트리가 있다. 이 트리를 $1$에서 $n$번까지의 색깔로 칠하는데 $k$번 색깔로 칠하는 데는 $k$만큼의 비용이 든다. 인접한 정점을 같은 색으로 칠하지 않으려고 할 때 최소 비용을 구하는 문제이다.
- 풀이 과정
이 문제를 푸는 데 가장 핵심적인 아이디어는, 트리를 색칠하는 데 많아 봐야 $\lg n$개의 색깔만이 필요하다는 것이다. 맨 처음에는 직관적으로 이 성질을 생각했는데, 좀 더 생각해보고 증명을 해 볼 수 있었다. 증명은 아래에!
만약 트리를 색칠하는 데 필요한 색깔이 생각보다 적다는 것을 알게 되면, 다음과 같은 동적계획법을 생각해 볼 수 있다.
$DP[a][c]=$ $a$번째 정점을 색깔 $c$로 칠했을 때 $a$번째 정점을 루트로 하는 서브트리를 칠하는 데 드는 최소 비용
$\lg n=q$라고 두자. 그러면 $i$번째 정점의 자식들이 $A[1] ... A[k]$로 총 $k$개 존재한다고 할 때, $$DP[a][c]= \sum_{i=1}^{k} \: \min_{j=1}^{q} \: \begin{cases} DP[A[i]][j] \; (j \neq c) \\ \infty \; (j=c) \end{cases} $$ 이라고 할 수 있다. 이를 계산하면 각 함수 호출마다 색깔에 대해서 항상 $q$번 반복하고 정점은 모든 호출을 합쳐 $O(nq)$번만 방문하므로 amortized 비슷하게 해서 총 시간복잡도가 $O(nq^2)$임을 알 수 있다.
그러면 이제 트리를 색칠하는 데 많아 봐야 $\lg n$개의 색깔만이 필요하다는 것을 증명하는 것 만이 남았다. 만약 어떤 트리가 있는데 그 아래 노드들은 어떻게 모두 색칠을 했고 루트 노드만 색칠을 해야 한다고 하자. 루트 노드에 색칠할 수 있는 가장 싼 색깔이 p라고 하자. 그러면 루트 노드의 자식 노드들은 자명하게 $1$~$p-1$번 색깔들을 모두 하나씩 가지고 있을 것이다. 여기서 트리의 크기에 대해 관찰해보면 다음과 같이 수학적 귀납법을 사용할 수 있다. $p=1$일 경우 정점이 하나일 것이기 때문에 이 트리의 크기는 1이다. $p$일때 트리의 크기를 $T[p]$라고 하면, $ T[p+1] \geq {\sum_{i=1}^{p} T[i]}+1 $ 이고 따라서 $p$가 1씩 증가할 때마다 $T[p]$는 항상 2배를 넘게 증가한다는 것을 알 수 있다. 여기서 $T[p]$가 $n$이므로 $p$는 $q$와 경향성이 같다. 완전히 같지는 않겠지만, 대강 실제로 색칠하는 데 필요한 색깔은 $\lg n$개보다 작거나 같음을 알 수 있겠다.
이 모든 것들을 종합하여 트리 DP로 문제를 해결할 수 있다.
- 소스 코드
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 | #include <bits/stdc++.h> using namespace std; typedef vector<int> vi; const int INF=987654321; vector<vi> adj,dir; int n; int check[100005]; void dfs(int x) { check[x]=1; for(int i=0;i<adj[x].size();i++) { int there=adj[x][i]; if(check[there]) continue; dir[x].push_back(there); dfs(there); } } int cache[100005][105]; int dp(int idx, int c) { int& ret=cache[idx][c]; if(ret!=-1) return ret; int s=0; for(int i=0;i<dir[idx].size();i++) { int there=dir[idx][i]; int mi=INF; for(int j=1;j<=18;j++) if(c!=j) mi=min(mi,dp(there,j)); s+=mi; } return ret=s+c; } int main() { memset(cache,-1,sizeof(cache)); scanf("%d",&n); adj.assign(n+1,vi()); dir.assign(n+1,vi()); for(int i=0;i<n-1;i++) { int t1,t2; scanf("%d %d",&t1,&t2); adj[t1].push_back(t2); adj[t2].push_back(t1); } dfs(1); int mi=INF; for(int i=1;i<=18;i++) mi=min(mi,dp(1,i)); printf("%d",mi); return 0; } | cs |