Problem Solving/문제풀이

[BOJ] 15404 Divide and Conquer

Ryute 2020. 8. 10. 04:28

문제 요약

$N$개의 정점이 있고, 이 정점을 잇는 스패닝 트리가 2개 있다. 이 두 스패닝 트리의 합집합으로 구성되는 그래프에 대해(간선은 구별된다), $K$개의 간선을 제거하면 그래프를 연결 그래프가 아니도록 할 수 있다. 이 때 $K$의 최솟값을 구하고, 그 경우의 수를 $10^9 + 7$로 나눈 나머지를 출력하면 된다.

문제 풀이

가장 먼저 해야 하는 관찰은 다음과 같다.

  • $K=2$이거나, $K=3$.

증명은 다음과 같다.

  • 임의의 정점에 연결된 모든 간선을 제거하면 그 정점은 다른 정점과 연결되어 있지 않기에, 그래프는 연결 그래프가 아니다. 따라서 $K$는 모든 정점의 degree 중 최솟값보다 클 수 없다. 간선 하나는 정확히 두 개의 정점의 degree를 1만큼 증가시키므로, 모든 정점의 degree의 총합은 $2\times2\times(N-1)=4N-4$이다. 비둘기집의 원리에 의해 $\displaystyle \frac{4N-4}{N} < 4$이므로 $K \ge 4$ 일 수 없다.
  • 각 스패닝 트리의 관점에서 볼 때, 모든 임의의 컴포넌트는 서로 연결되어 있으므로 이를 분리시키는 데 각 스패닝 트리 당 최소한 하나의 간선을 제거해야 하고, 이는 서로 독립이다. 따라서 $K>1$이다.

이 관찰을 하고 나면, 자명한 다음 사실을 활용하여 풀이를 확장시킬 수 있다.

  • 제거하는 간선들이 원래 어느 스패닝 트리에 속했는지를 생각하면 $K=2$일 때는 두 스패닝 트리에서 하나씩, $K=3$일 때는 한 스패닝 트리에서 한 개, 나머지 스패닝 트리에서 두 개를 제거해야 한다.

이는 한 스패닝 트리에서 모든 간선을 제거하면 나머지 스패닝 트리가 연결 그래프를 구성할 테니 직관적으로 이해할 수 있다. 위 정리를 다시 쓰면, 두 스패닝 트리 중 적어도 하나에서는 정확히 한 개의 간선을 제거한다는 뜻이 된다. 트리에서 하나의 간선을 제거한다는 뜻은, 트리를 어떤 서브트리와 나머지 정점들로 나눈다는 뜻과 같다. (이때 서브트리의 루트를 $T$라고 하면, $T$와 $T$의 부모를 잇는 간선이 제거된 간선이 된다.) 따라서 두 스패닝 트리 각각에 대해 모든 임의의 간선을 제거해본다고 하자. 그렇다면 정답은 (각 간선을 제거했을 때 나머지 스패닝 트리에서 간선을 최대 $K-1$개 제거해 서브트리와 나머지 정점을 분리할 수 있는 경우의 수)를 모든 경우에 대해 더한 것이 된다. $K=2$인 경우에는 두 간선을 제거하는 것이 구분되지 않으므로, 이를 2로 나누어 주어야 한다.

간선을 하나 제거할 스패닝 트리를 $A$로 부르고, 다른 하나의 스패닝 트리를 $B$라고 하자. 잘 생각해 보면, $A$에서 한 간선을 제거하면 $B$에서 간선을 제거할 수 있는 경우의 수는 반드시 최대 하나로 고정된다는 것을 알 수 있다. 이를 정리하면 다음과 같이 표현할 수 있다.

  • $A$에서 간선 하나를 제거해서 만들어진 서브트리를 $T$라고 하자. $T$와 $T^c = A-T$를 잇는 간선(당연히 $B$에 속한다)의 집합을 $E$라 할 때, $|E|\ge 3$이면 $K>3$이 되므로 가능한 경우가 없다. $|E|=1,2$면 각각 $K=2,3$인 경우에 대응되고, 이 때 모든 간선을 반드시 제거해야 하므로 경우의 수는 하나로 유일하다.

따라서 우리는 문제를 $T$와 $T^c$를 잇는 간선이 몇 개나 있는지 세는 문제로 바꾸어 줄 수 있다. $B$에 속한 어떤 간선이 $T$와 $T^c$를 잇기 위해서는, 한쪽 끝은 $T$에 속하고 다른 한 끝은 $T^c$에 속해야 한다. 이 두 끝을 $p$와 $q$라고 하자. 트리 위에서 두 정점을 잇는 경로는 유일하므로, $A$에서 $p$와 $q$를 잇는 경로는 반드시 끊긴 간선인 $T$의 루트와 그 부모를 잇는 간선(이하 $e$라 호칭)을 지난다. 다시 말해, $B$를 구성하는 모든 간선중 그 간선의 양 끝점을 잇는 $A$ 위의 경로가 $e$를 지나는 간선의 수를 찾아야 한다.

이를 Naive하게 해결하는 방법은 다음과 같다. 실제로 $B$의 각 간선에 대해서 양쪽 끝을 잡고, $A$ 위에서 한쪽 끝에서부터 다른쪽 끝으로 이동하면서 만나는 모든 간선에 1을 더해 준다. 그러면 $A$의 각 간선에는 $B$의 간선이 표현하는 경로 중 몇 개의 경로가 그 간선을 경유하는지 여부가 적혀 있게 된다. 하지만 이는 너무 비효율적이니, 다음과 같이 개선할 수 있다. $LCA(A,B)=L$이라고 하면 $B$의 하나의 간선이 갱신하는 $A$의 간선들은 $A \rightarrow L$과 $B \rightarrow L$의 경로 두 개로 쪼개진다. 따라서 트리 위에서 변홧값 배열을 사용하면 각 LCA를 구하는데 $O(\lg N)$이, $B$가 표현하는 경로 하나에 모두 1을 더해 주는 데에 $O(1)$이 걸리니 우리가 구하고 싶은 값을 $O(N \lg N)$에 전처리해둘 수 있다.

이 값이 전처리된다면 문제를 해결하는 것은 쉽다. 모든 간선에 대해 그 간선을 끊었다고 가정하고, 그 간선에 써져 있는 값이 1이나 2일 때 해당하는 $K$에 대한 경우의 수를 추가해 주면 된다. $K=2$인 경우의 수가 0일 경우에만 자동적으로 $K=3$이 된다.

소스 코드

#include <bits/stdc++.h>
#include <random>
#include <cassert>
#define all(x) (x).begin(), (x).end()
#define endl "\n"
#define ends " "
#define pb(x) push_back(x)
#define fio() ios_base::sync_with_stdio(0); cin.tie(0)
#define fileio() ifstream file_in("input.txt");ofstream file_out("output.txt")
/*#define cin file_in
#define cout file_out*/

using namespace std;

typedef long long ll;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
typedef tuple<int, int, int> tpi;
typedef tuple<ll, ll, ll> tpl;
typedef pair<double, ll> pdl;

const int MOD = 1000000009;
const ll LMOD = 1000000009;
const int INF = 0x3f3f3f3f;
const ll LINF = 0x3f3f3f3f3f3f3f3f;
const double pi = acos(-1);
const double eps = 1e-10;
const int dx[] = { 0,1,0,-1 };
const int dy[] = { 1,0,-1,0 };

int n;
vector<int> adj[2][101010];
int dep[2][101010], pa[2][101010][20], c[2][101010];
int ans[10];

void dfs(int x, int h, int par, int d) {
    dep[x][h] = d;
    for (int th : adj[x][h]) {
        if (th == par) continue;
        dfs(x, th, h, d + 1);
        pa[x][th][0] = h;
    }
}

void process(int x) {
    pa[x][1][0] = 1;
    dfs(x, 1, 1, 1);
    for (int i = 1; i <= 19; i++)
        for (int j = 1; j <= n; j++)
            pa[x][j][i] = pa[x][pa[x][j][i - 1]][i - 1];
}

int lca(int x, int a, int b) {
    if (dep[x][a] > dep[x][b]) swap(a, b);
    int t = dep[x][b] - dep[x][a];
    for (int i = 19; i >= 0; i--)
        if ((1<<i) & t) b = pa[x][b][i];
    if (a == b) return a;
    for (int i = 19; i >= 0; i--) {
        if (pa[x][a][i] == pa[x][b][i]) continue;
        a = pa[x][a][i], b = pa[x][b][i];
    }
    return pa[x][a][0];
}

int dfs2(int x, int h, int par) {
    int s = 0;
    for (int th : adj[x][h]) {
        if (th == par) continue;
        s += dfs2(x, th, h);
    }
    c[x][h] = c[x][h] + s;
    if (c[x][h] <= 2) ans[c[x][h]]++;
    return c[x][h];
}

void solve(int x) {
    process(x);
    for (int i = 1; i <= n; i++) {
        for (auto h : adj[1 - x][i]) {
            int a = i, b = h;
            if (a > b) continue;
            int l = lca(x, a, b);
            c[x][a]++, c[x][b]++, c[x][l] -= 2;
        }
    }
    dfs2(x, 1, 1);
}

int main() {
    fio();
    cin >> n;
    for(int i=0;i<2;i++)
        for (int j = 0; j < n - 1; j++) {
            int t1, t2;
            cin >> t1 >> t2;
            adj[i][t1].push_back(t2);
            adj[i][t2].push_back(t1);
        }
    solve(0); solve(1);
    if (ans[1]) cout << "2 " << ans[1] / 2;
    else cout << "3 " << ans[2];
    return 0;
}