ckw1140's profile image

ckw1140

March 8, 2019 12:00

meet in the middle

Meet in the middle

meet in the middle 은 절반 크기의 비슷한 문제를 두 번 해결한 결과를 통해 본 문제를 해결함으로서 문제 해결에 소요되는 시간 복잡도의 향상을 꾀하는 방법입니다.

저 말만 들으면 어떠한 장점이 있는지 감이 오지 않지만, 문제를 푸는 시간 복잡도가 exponential 한 경우라면 아래와 같이 개선이 된다는 것을 느낄수 있을 것입니다.

$2^n > 2 * 2^{n/2}$

구체적인 예제를 통해 설명해 보도록 하겠습니다.

BOJ 1208 부분집합의 합2

이 문제는 $N(N\leq 40)$개의 수로 이루어진 집합이 주어졌을 때, 원소의 합이 $S$가 되는 부분 집합의 개수를 구해야 하는 문제입니다.

가장 간단한 풀이로는 모든 부분 집합에 대하여 원소의 합이 몇인지 구하고 그 값이 $S$인 것의 개수를 세는 것입니다. 그러나 이 방법은 $N = 40$ 일 때는 $2^{40}$가지의 부분집합을 모두 확인해야 하므로 시간이 너무 많이 소요됩니다.

이러한 상황에서 meet in the middle 이 어떠한 방식으로 알고리즘의 시간 복잡도를 개선하는지 보도록 하겠습니다.

우선 집합을 $N/2$개씩 두 개의 집합으로 분리해 봅시다. 편의상, 각각을 $A, B$라고 부르겠습니다.

우선 $A$의 부분 집합의 원소의 합들을 모두 계산하여 각각의 합이 몇 번씩 나오는지를 미리 계산해 둡니다. 이는 나올 수 있는 합의 종류의 개수 정도의 크기를 갖는 배열을 잡으면 깔끔하게 구현할 수 있습니다. 이 배열을 $cnt$라고 하면, 이 과정을 수행한 결과 $cnt_i$에는 원소의 합이$i$인 부분 집합의 개수가 저장되어 있습니다.

이 부분의 시간 복잡도는 $O(2^{N/2})$ 입니다.

이제, $B$의 모든 부분집합의 합들을 계산합니다. 이 합들의 집합을 $sum$ 이라고 합시다. 이제 각각의 $sum$ 의 원소들에 대하여 $cnt_{S - sum_i}$ 의 합을 구해주면 이 값이 본 문제의 답이 됩니다.

이 부분의 시간 복잡도 또한 $O(2^{N/2})$ 입니다.

따라서, meet in the middle 의 방법으로 모든 문제를 $O(2^{N/2})$의 시간 복잡도로 해결하였습니다.

BOJ 4357 이산 로그 - meet in the middle 을 통해 시간 복잡도의 개선을 할 수 있는 대표적인 문제입니다. 실제로 암호 기법 중에 이산 로그를 계산하는 것의 어려움에 기반을 둔 암호들을 공격하는 방법으로 meet in the middle 방법이 종종 소개되곤 합니다.

이 문제는 $B, N, P$ ($P$는 소수)가 주어질 때, $B^L \equiv N (mod$ $P)$ 를 만족하는 $L$ 값을 구하는 문제입니다.

이 문제를 해결하는 naive한 방법은 $L$ 에 $0$ 부터 $P - 2$ 까지의 모든 값들을 대입해 보는 것입니다. 대입하여 계산한 결과가 위의 조건식을 만족한다면 우리가 원하는 답이 되는 것이지요. 이 방법의 시간 복잡도는 $O(P)$ 가 됩니다.

이제 meet in the middle 의 방법을 사용한 보다 효율적인 풀이를 알아보겠습니다.

우선 적당한 정수 $T$를 생각합시다. 그러면 $L = X * T + Y$ (단, $0 \leq X \leq P / T, 0 \leq Y < T$) 꼴로 표현된다는 점에 주목해봅시다.

첫번째로, 모든 $X$ 값에 대해 $B^{X*T} (mod$ $P)$ 의 값을 전부 계산한 뒤 적절한 자료구조(sorting 또는 균형 이진트리)에 기록해 둡니다. 가능한 $X$ 값의 경우의 수가 $P / T$가지 이므로 이 과정에는 $O(P / T)$ 의 시간이 소요됩니다.

두번째로, 모든 $Y$ 값에 대해 $B^Y (mod$ $P)$의 값을 계산합니다. 각각을 계산할 때마다 자료구조에 $B^{-Y} * N (mod$ $P)$의 값이 저장이 되어있는지 확인합니다. 만약 존재한다면 해당되는 $X, Y$에 대해 $L = X * T + Y$ 가 우리가 찾는 이산 로그의 값이 됩니다. 가능한 $Y$ 값의 경우의 수가 $T$가지 이므로 이과정에는 $O(T)$의 시간이 소요됩니다.

따라서 $T$의 값을 $\sqrt{P}$ 정도의 값으로 잡아준다면, 우리는 naive 한 방법보다 훨씬 개선된 $O(\sqrt{P})$의 시간 복잡도로 이산 로그 문제를 해결할 수 있게 됩니다.

BOJ 13169 Xor of Sums

이 문제는 주어진 multiset 의 부분집합의 합들의 Xor 값을 구하는 문제입니다.

문제를 해결하는 naive한 방법은 직접 모든 부분 집합에 대해 원소들의 합을 구한뒤 Xor 하는 방법입니다.

부분 집합의 개수는 $2^n$ 개 이고, 부분 집합의 원소들의 개수가 $n$개 이므로 이 방법의 시간 복잡도는 $O(n*2^n)$ 이고, $n \leq 30$ 이므로 제한 시간 내에 해결을 할 수 없게 됩니다.

이 문제 또한 meet in middle 을 통해 효율적으로 해결할 수 있습니다.

우선 multiset 을 $n/2$ 정도의 크기를 갖는 두개의 disjoint 한 multiset 으로 분리합니다. 그런 다음 각각의 multiset 에 대해서 모든 부분 집합의 합을 미리 계산하여 둡니다.

그 다음 결과 값의 모든 가능한 비트($lgMAXVAL$ 개의 비트)에 대해 독립적으로 고려하여 봅니다.

고려할 비트를 고정하였다면, 한 쪽의 multiset의 합들을 순회하면서 다른 multiset 의 합 중에서 몇개의 값들이 지금의 합과 합쳐졌을 때 현재 고려하고 있는 비트에 영향을 미칠지(=Xor 하여 비트를 flip할지) 를 빠르게 구하면 됩니다.

그 방법은 아래와 같습니다.

만일 우리가 다른 multiset 의 합들을 현재 고려하는 비트의 이하의 비트들만 남겨둔 배열을 정렬된 상태로 갖고 있다면 영향을 미치는 합들은 이 배열에서 연속된 구간의 모양을 이룰 것입니다.

따라서 이분 탐색을 통해 그러한 첫번째 수와 마지막 수의 위치를 구한다면 그 개수 또한 쉽게 구할 수 있습니다.

이제 우리는 최종 답을 저장할 변수를 0에서 시작하여 이 개수가 홀수라면 결과 값의 현재 고려하는 비트를 flip 하여 주고 아니면 그냥 넘어가기만 하면 됩니다.

이 과정에서 고려해야 하는 비트의 개수가 $lgMAXVAL$ 개 이며, 각 비트별로 $2^{n/2}$ 개의 합들을 순회하며 $2^{n/2}$ 개의 수들에 대해 이분 탐색을 진행합니다.

따라서 총 $O(n * 2^{n/2} * lgMAXVAL)$ 에 해결이 가능합니다.

아래는 문제를 위와 같은 방법으로 해결하는 코드입니다.

#include<bits/stdc++.h>
using namespace std;

typedef long long ll;

int N;
int A[33];
vector<ll> S[44];

int count(int b, ll x) {
    return upper_bound(S[b].begin(), S[b].end(), x) - S[b].begin();
}
int count(int b, ll l, ll r) {
    return count(b, r) - count(b, l - 1);
}

void main2(int tc) {
    scanf("%d", &N);

    for(int i = 0; i < N; i++) {
        scanf("%d", &A[i]);
    }

    if(N == 1) {
        printf("%d\n", A[0]);
        return;
    }

    int n1 = N / 2;
    int n2 = N - n1;

    for(int i = 0; i < 35; i++) S[i].clear();
    for(int mask = 0; mask < (1 << n1); mask++) {
        ll sum = 0;
        for(int i = 0; i < n1; i++) if(mask & (1 << i)) sum += A[i];
        for(int i = 0; i < 35; i++) S[i].push_back(sum & ((1LL << (i + 1)) - 1));
    }

    for(int i = 0; i < 35; i++) {
        sort(S[i].begin(), S[i].end());
    }

    ll ans = 0;
    for(int mask = 0; mask < (1 << n2); mask++) {
        ll sum = 0;
        for(int i = 0; i < n2; i++) if(mask & (1 << i)) sum += A[n1 + i];
        for(int i = 0; i < 35; i++) {
            ll t = sum & ((1LL << i) - 1);

            if(count(i, (1LL << i) - t, (1LL << (i + 1)) - t - 1) % 2) {
                ans ^= (1LL << i);
            }
        }
    }
    printf("%lld\n", ans);
}

int TC = 1;
int main() {
    for(int i = 1; i <= TC; i++) main2(i);
}