ckw1140's profile image

ckw1140

July 19, 2019 12:00

알고리즘 문제 풀이4

알고리즘 문제 풀이 4

최근에 푼 재미있는 문제들을 포스팅 해보겠습니다.

[BOJ 2899] 구슬 없애기

구슬 $N$개가 일렬로 놓여있습니다. 같은 색의 구슬이 연속 K개 이상 놓여 있으면 그 구슬들을 없앨 수 있습니다.

구슬을 모두 없애고 싶지만 현재 상태로는 구슬을 모두 없애는 것이 불가능할 수도 있습니다.

현재 구슬의 사이에 마음대로 구슬을 추가할 수 있을 때, 구슬을 모두 없애기 위해 추가해야 하는 구슬의 최소 개수를 구하는 문제입니다.

복잡해 보이지만 다이나믹 프로그래밍을 사용하면 비교적 간단하게 해결 가능한 문제입니다.

먼저 다음과 같은 다이나믹 프로그래밍을 생각해 봅시다.

dp(l, r) := l부터 r까지의 구간에 있는 구슬을 모두 없애기 위해 필요한 최소 추가 수

이 상황에서 마지막에 없앤 구슬들이 어떤 것들이었을 까요?

그 구슬들을 $i_1, i_2, i_3, …, i_l$ 이었다고 해봅시다.

우선 순전히 이 구슬들을 없애기 위해 추가한 구슬을 몇개 일까요? 즉, 추가한 구슬 중에 이 구슬과 같이 사라지는 구슬은 몇개 일까요?

바로 $max(0, K - l)$ 개 일 것입니다.

또한 마지막에 저 구슬들을 없애는 상황에서 저 구슬들은 연속해 있었겠죠? 따라서 저 구슬들을 없애는 작업을 하기 전에 구간 $[l, i_1 - 1], [i_1 + 1, i_2 - 1], …, [i_l + 1, r]$ 들에 있는 구슬이 모두 사라진 상태일 것입니다.

여기서 원래 문제의 부분 문제의 구조가 보입니다.

모든 가능한 집합 $I =$ { $i_1, i_2, …, i_l, l \leq i_j \leq r$ }에 대해서

dp(l, r) = $Min_I(dp(l, i_1-1) + dp(i_1+1, i2-1) + … + dp(il+1, r) + max(0, K - l))$

가 될 것입니다.

이를 효과적으로 구현하기 위해 dp를 재정의 합시다.

dp(l, r, k) := l부터 r까지에 있는 구슬을 없앨 것이고, 현재 l은 앞에 같이 없앨 구슬이 l을 포함하여 k개가 존재한다.

그러면 transition에서 l과 같이 마지막에 없앨 다음 구슬을 골라주는 작업을 하면 됩니다.

더 자세한 설명보다는 다음 코드를 참고하시면 도움이 될것이라고 생각합니다.

시간 복잡도는 (dp의 상태공간 개수) * (transition 의 개수) 이므로 $O(N^3K)$ 가 됩니다.

#include<bits/stdc++.h>
using namespace std;

int N, K;
int A[111];

int cc[111][111][6];
int dp(int l, int r, int k) {
    if(l > r) return 0;
    int &ret = cc[l][r][k];
    if(ret != -1) return ret;

    ret = K - k + dp(l + 1, r, 1);

    for(int i = l + 1; i <= r; i++) {
        if(A[i] == A[l]) ret = min(ret, dp(i, r, min(K, k + 1)) + dp(l + 1, i - 1, 1));
    }
    return ret;
}

int main() {
    scanf("%d %d", &N, &K);

    for(int i = 0; i < N; i++) {
        scanf("%d", &A[i]);
    }

    memset(cc, -1, sizeof(cc));
    printf("%d", dp(0, N - 1, 1));
}

[BOJ 15480] LCA와 쿼리

트리가 하나 주어지고 그 트리 위에서 다음과 같은 쿼리를 여러번 수행해야 합니다.

r u v : 트리의 루트가 r이라고 했을 때, u, v의 lca 를 구하시오.

이 문제는 케이스를 잘 분류하면 쉽게 풀 수 있습니다.

우선 1번 정점을 루트라고 생각하고 트리를 널어 봅시다.

그 다음 쿼리의 케이스를 분류하면 다음의 경우들이 가능할 것입니다.

  1. u, v 가 모두 r의 자손인 경우 이 경우는 아주 간단합니다. 그냥 루트가 1일 때 u, v의 lca와 같습니다.
  2. 둘 중에 하나는 r의 자손이고 하나는 r의 자손이 아닐 때 이 경우도 간단합니다. 당연히 r이 lca 가 되겠죠.
  3. 둘 다 r의 자손이 아닌 경우 이 경우는 생각하기 까다로울 수 있지만 r에서 1로 가는 간선들을 쭉 늘어 놓고 u와 v가 어느 순간에 갈라져 나오는지를 생각하면 편합니다. 3-1. u, v가 서로 다른 깊이에서 갈라져 나온다. — 이 경우에는 더 깊은(r과 가까운)곳에서 갈라져 나온 정점이 갈라져 나온 그 지점이 바로 lca가 되겠습니다. 이 정점을 u였다고 하면 lca(u, r)로 갈라지는 정점을 찾을 수 있겠네요. 3-2. u, v가 서로 같은 깊이에서 다른 자식들로 갈라져 나온다. — 이 경우에는 서로 동시에 갈라진 그곳이 lca 이겠네요. 단순히 lca(u, v)로 구할 수 있습니다. 3-3. u, v가 서로 같은 깊이에서 같은 자식들로 갈라져 나온다. — 이 경우에는 당연히 lca(u, v)를 구하면 되겠네요.

따라서 다음과 같은 코드로 문제를 해결할 수 있습니다.

각 쿼리에서 lca를 구하는 연산을 많아야 3번 하므로 lca 를 log시간에 빠르게 구하는 알고리즘을 사용하면 총 시간 복잡도는 $O(NlogN + QlogN)$ 이 됩니다.

#include<bits/stdc++.h>
using namespace std;

const int MN = 100010;

int N, M;
vector<int> adj[MN];
int dep[MN], par[20][MN], tin[MN], tout[MN], timer;

void dfs(int u, int p) {
    tin[u] = timer++;

    par[0][u] = p;
    for(int i = 1; i < 20; i++) {
        int t = par[i - 1][u];
        if(t == -1) break;
        par[i][u] = par[i - 1][t];
    }
    for(int i = 0; i < adj[u].size(); i++) {
        int v = adj[u][i];
        if(v == p) continue;
        dep[v] = dep[u] + 1;
        dfs(v, u);
    }

    tout[u] = timer;
}

int lca(int a, int b) {
    if(dep[a] < dep[b]) swap(a, b);
    int diff = dep[a] - dep[b];
    for(int i = 0; i < 20; i++) if(diff & (1 << i)) a = par[i][a];
    if(a == b) return a;
    for(int i = 20; i--;) {
        if(par[i][a] != par[i][b]) {
            a = par[i][a];
            b = par[i][b];
        }
    }
    return par[0][a];
}

int main() {
    scanf("%d", &N);

    for(int i = 0; i < N - 1; i++) {
        int u, v; scanf("%d %d", &u, &v);
        u--; v--;

        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    memset(par, -1, sizeof(par));
    dfs(0, -1);

    scanf("%d", &M);

    for(int i = 0; i < M; i++) {
        int r, u, v; scanf("%d %d %d", &r, &u, &v);
        r--; u--; v--;

        if(tin[r] <= tin[u] && tin[u] < tout[r]) {
            if(tin[r] <= tin[v] && tin[v] < tout[r]) printf("%d\n", lca(u, v) + 1);
            else printf("%d\n", r + 1);
        }
        else {
            if(tin[r] <= tin[v] && tin[v] < tout[r]) printf("%d\n", r + 1);
            else {
                int x = lca(r, u);
                int y = lca(r, v);
                int z = lca(u, v);

                if(dep[x] != dep[y]) printf("%d\n", dep[x] > dep[y]? x + 1 : y + 1);
                else {
                    if(x == z) printf("%d\n", x + 1);
                    else printf("%d\n", z + 1);
                }
            }
        }
    }
}