Problem Solving/문제풀이

[BOJ] 24973 Up Down Subsequence

Ryute 2023. 2. 8. 01:58

문제 요약

길이가 $N$인 수열이 주어진다. 어떤 부분수열을 앞에서부터 읽었을 때 인접한 두 수가 커지면 U, 작아지면 D를 적은 문자열이 있다고 하자. 문자열 $S$가 주어질 때 부분수열을 잘 골라서 만든 문자열과 $S$의 common prefix 길이를 최대화하여라.

문제 풀이

결론부터 말하면, 다음을 반복하면 된다.

1. 문자열을 (연속한 문자 종류, 개수)로 압축한다.
2. 맨 앞이 U면 LIS, D면 LDS를 길이가 그 문자의 개수가 될 때까지 찾는다.
3-1. 만약 찾지 못한다면 LIS나 LDS의 길이(정확히는 길이-1)만큼 추가로 문자를 처리할 수 있고 그 뒤는 불가능하다.
3-2. 만약 찾는다면 연속한 문자를 모두 처리할 수 있다. LIS/LDS의 마지막 위치가 pos라고 하자. 그러면 계속 증가하는 동안/계속 감소하는 동안 pos를 계속 우측으로 민다. 그곳에서부터 다시 1번 과정을 시작한다.

이게 왜 되는지 직관적으로 생각해보자. 

  • 일단 당연히 U만 잔뜩 있으면 LIS 문제와 같으므로 반드시 이 문제의 풀이는 LIS의 상위호환이어야 한다.
  • 어차피 prefix를 찾는 것이므로 greedy하게 앞에서부터 되는 만큼 반복적으로 찾아주어도 된다.

그러니까 그냥 맨 앞에 있는 U들만 따로 떼고 생각해보자. 그러면 U의 개수만큼 LIS를 찾았을 때 모든 U를 처리할 수 있다. 문제는 그 다음에는 D만 따로 떼고 생각할 것이기 때문에 LDS를 찾아주어야 한다는 점이다. 그럼 LDS의 시작점이 가능한 큰 수여야 유리하지 않겠는가? LDS의 시작점이 LIS의 끝점과 같으므로 LIS의 끝점을 최대화해보자. LIS의 끝점을 오른쪽으로 옮겨도 답이 같다면, 즉 수가 증가한다면 이 수는 어차피 LDS의 길이를 증가시키는  데 사용하지 못하기 때문에 한 칸 옮기는 게 이득이다. 이를 반복한다. 

사실 이정도만 해도 formal한 proof와 크게 다르지 않는데, LIS가 증가할 때 increasing chain의 끝까지만 가는 것이 항상 최적인 이유는 어차피 LDS의 두 번째 원소 이후로는 다시 문제 상황이 원상복구 되므로 두 번째 원소의 후보들을 가능한 많이 만드는 것이 최적이기 때문이다. increasing chain보다 오른쪽으로 가서 더 큰 수를 찾아도, 어차피 거기에서 내려가는 것보다 increasing chain의 끝점에서 내려가는 쪽이 반드시 두 번째 원소가 빠르다(increasing chain의 종점은 decreasing). 

구현도 쉽다.

#include <bits/stdc++.h>

#define all(x) (x).begin(), (x).end()
#define endl "\n"
#define ends " "
#define fio()                     \
    ios_base::sync_with_stdio(0); \
    cin.tie(0)
using namespace std;

typedef long long ll;
typedef pair<int, int> pii;

int n;
int a[303030];
string s;
vector<pii> b; // 0 = U, 1 = D

int f(char x) {
    if (x == 'U') return 0;
    else return 1;
}

pii go(int type, int cnt, int ipos) {
    // ipos부터 시작해서 type를 cnt번 사용할 수 있는가?

    function<bool(int, int)> comp;
    if (type == 0)
        comp = [](int a, int b) {return a < b; };
    else
        comp = [](int a, int b) {return a > b; };

    int last = -1;
    vector<int> lis;
    for (int i = ipos; i <= n; i++) {
        if (lis.empty() || comp(lis.back(), a[i]))
            lis.push_back(a[i]);
        else {
            auto it = lower_bound(all(lis), a[i], comp);
            *it = a[i];
        }
        if (lis.size() == cnt + 1) {
            last = i;
            break;
        }
    }
    if (last == -1) return make_pair(0, (int)lis.size() - 1);

    while (last + 1 <= n && comp(a[last], a[last + 1])) last++;
    return make_pair(1, last);
}

int main() {
    fio();

    cin >> n;
    for (int i = 1; i <= n; i++)
        cin >> a[i];
    cin >> s;

    for (int i = 0; i < s.length(); i++) {
        if (!i || b.back().first != f(s[i]))
            b.emplace_back(f(s[i]), 1);
        else
            b.back().second++;
    }

    int pos = 1, ans = 0;
    for (auto [type, cnt] : b) {
        auto [ret, val] = go(type, cnt, pos);
        if (!ret) {
            ans += val;
            break;
        }
        else {
            ans += cnt;
            pos = val;
        }
        if (pos == n) break;
    }
    cout << ans;

    return 0;
}