연산에 대한 O(n log n) / O(1) 구간 쿼리
프로그래밍을 하면서 많이 맞닥뜨리는 연산중 하나는 “구간 쿼리”입니다. 결합법칙을 만족하는 어떤 종류의 연산이든, $O(N \log N)$ 전처리로 쿼리당 $O(1)$ 시간에 구간에 대한 쿼리를 답할 수 있는 구조를 설명하려고 합니다.
구간 쿼리
구간 쿼리는 다음과 같은 문제입니다.
- (전처리) 배열 $A = (A_0, A_1, \cdots, A_{N-1})$과 연산 $\circ$가 주어집니다.
- (쿼리) $1 \le l \le r \le N$이 주어지면, $A_l \circ A_{l+1} \circ \cdots \circ A_r$을 계산해야합니다.
여기서, 우리가 주목해볼 연산의 성질은 다음과 같습니다.
- 결합법칙: 임의의 세 원소 $a, b, c$에 대해 $a \circ (b \circ c) = (a \circ b) \circ c$가 성립합니다.
- 교환법칙: 임의의 두 원소 $a, b$에 대해 $a \circ b = b \circ a$가 성립합니다.
- 멱등법칙: 임의의 원소 $a$에 대해 $a \circ a = a$가 성립합니다.
- 항등원의 존재: 임의의 원소 $a$에 대해 $a \circ e = e \circ a = a$ 를 만족하는 고정된 $e$가 존재합니다.
- 역원의 존재: 임의의 원소 $a$에 대해 $a \circ a^{-1} = a^{-1} \circ a = e$를 만족하는 $a^{-1}$가 존재합니다.
이 법칙들, 항등원 혹은 역원의 존재 여부는 연산마다 다릅니다.
- $+$는 결합법칙, 교환법칙이 성립하고 항등원, 역원이 존재합니다.
- $\times$는 결합법칙, 교환법칙이 성립하고 항등원이 존재합니다.
- $-$는 어떤 법칙도 성립하지 않고, 항등원도 존재하지 않습니다.
- 두 행렬의 곱셈은 결합법칙이 성립하고 항등원이 존재합니다.
- 최솟값, 최댓값 연산은 결합법칙, 교환법칙, 멱등법칙이 성립하며, 경우에 따라 항등원이 존재합니다. (주로 $+\infty$, $-\infty$로 표기합니다)
- $(L_1, R_1) \circ (L_2, R_2) = (L_1, R_2)$로 정의되는 연산이 있다고 합시다. 이 연산은 결합법칙, 멱등법칙이성립합니다.
이렇게 연산의 종류에 따라 다양한 성질이 존재하고, 이 성질에 맞는 특징을 이용해서 문제를 해결하면 됩니다.
- 결합법칙이 성립하고 역원이 존재하는 연산에 대해서는 $O(N)$ 전처리로 쿼리당 $O(1)$ 시간에 문제의 답을 할 수 있습니다.
- 결합법칙, 교환법칙, 멱등법칙이 성립하는 연산에 대해서는 $O(N \log N)$ 전처리로 쿼리당 $O(1)$ 시간에 문제의 답을 할 수 있습니다.
첫 번째 방식은 “누적합”, 두 번째 방식은 “희소 배열”이라는 이름으로 잘 알려져있습니다. 우리는 두 번째 방식에서 조건을 약화시킨, “결합법칙”이 성립하는 임의의 연산에 대해서 $O(N \log N)$ 전처리로 쿼리당 $O(1)$ 시간에 문제의 답을 하는 방법에 대해 알아볼 것입니다.
아이디어
기본적인 아이디어는, $A_l \circ A_{l+1} \circ \cdots \circ A_m$과 $A_{m+1} \circ A_{m+2} \circ \cdots \circ A_r$을 알고 있으면, 둘을 연산하는 것으로 $A_l \circ A_{l+1} \circ \cdots \circ A_r$을 계산할 수 있다는 것입니다. 전처리 과정에서 적당한 구간을 연산한 것들을 가지고 있으면, 문제를 해결할 수 있습니다. 편하게, $A_i \circ A_{i+1} \circ \cdots \circ A_j$를 $[i, j+1)$로 표현합시다.
우리가 쿼리당 정확히 한 번의 연산을 통해서 답을 도출해야한다고 하면 문제는 다음과 같이 바뀔 수 있습니다:
- 모든 $0 \le l < r \le N$ 에 대해서 적당한 $m$이 존재해서 $[l, m)$과 $[m, r)$이 모두 존재하도록 하면 됩니다.
가장 자연스러운 생각은 $m = \frac{0+N}{2}$로 잡는 것입니다. 이 이후, $[0, m); [1, m); \cdots; [m-1, m)$ 을 계산하고 $[m, m+1); [m, m+2); \cdots; [m, r)$을, 총 $N$개의 값을 계산합니다. 이는 반복문으로 쉽게 계산할 수 있습니다. 이제 $0 \le l \le m \le r \le N$을 만족하는 $[l, r)$에 대한 답을 내놓을 수 있습니다.
이제 계산하지 못하는 값은, $0 \le l < r \le m-1$과 $m+1 \le l < r \le N$ 부분이 됩니다. 문제의 형태가 똑같기 때문에 이는 재귀적으로 전처리해줄 수 있습니다. 어떤 $[l, r)$이 들어오면 $l \le m \le r$을 만족하는 올바른 $m$을 찾을 수 있습니다.
전처리에 드는 연산 횟수 및 시간복잡도는 $T(N) = 2T(N/2) + O(N)$ 으로 $T(N) = O(N \log N)$이 되며, 전처리는 한 번의 연산으로 찾을 수 있습니다.
구현
구현을 좀 더 단순하게 해봅시다. 위와같이 구현할 경우, 전처리 과정에서 해당하는 $m$을 찾는 과정이 $O(\log N)$시간이 걸리므로 조금 더 세련된 방법을 생각해봅시다. 가장 쉬운 방법은 $N$을 임의의 길이로 늘려서 $2^k$꼴로 만드는 것입니다. 이 경우, 아래쪽에 재귀호출을 부르는 모든 배열의 길이도 $2$의 거듭제곱꼴이 됩니다. 이제 $[L, R)$구간이 나눠지는 형태를 살펴봅시다.
- 처음에 $L = 0, R = 2^k$ 이고 구간의 길이 $R-L = 2^k$입니다. 여기서 주목할 점은, $L$과 $R$이 구간의 길이의 배수라는 점입니다.
- $L, R$이 $R-L = 2^i$의 배수라고 합시다. $m = \frac{L+R}{2}$가 되고, $L, m, R$모두 $m-L = R-m = 2^{i-1}$의 배수가 됩니다.
즉, 구간이 나눠지면서 양쪽 끝 인덱스는 구간 길이에 해당하는 $2^i$의 배수 형태를 유지하게 됩니다. $[L, R)$ 구간은 $j = 0, \cdots, k; v = 0, \cdots, 2^{k-j}-1$에 대해서 $[v \times 2^j, (v+1) \times 2^j)$꼴이 됩니다.
$N=16$일 때 저장된 값 |
이제 쿼리로 주어진 $[l, r)$에 대해 적당한 $m$을 찾아봅시다. 쿼리로 주어진 $[l, r)$에 대해서 구간의 양 끝점인 $l$과 $r-1$이 모두 속하는 가장 긴 구간 $[L, R)$을 찾으면, $L \le l < \frac{L+R}{2} < r \le R$이 됩니다. 이는 $L, \frac{L+R}{2}, R$이 연속된 $2^{j-1}$꼴의 배수라는 점에서 착안하여, $l$과 $r$이 달라지는 가장 큰 자리수 비트를 이용해서 구하면 됩니다. 이는 $l \oplus r$의 가장 큰 비트를 구하면 됩니다.
구현체
아래는 구현체입니다.
sp[i][j]
는 크기가 $2^i$인 구간에서 $j$번째 위치에 저장된 값을 의미합니다. 위 그림을 참고하면 좋습니다.- C++20의
bit_width
에 의존합니다.bit_width(x)
를32-__builtin_clz(x)
로 대체해도 똑같습니다. Query
함수는 $A_l \circ \cdots \circ A_{r-1}$을 계산합니다.
#include <bit>
#include <cassert>
#include <vector>
using namespace std;
template <typename S, S (*op)(S, S)>
class RMQ
{
private:
int N;
vector<vector<S>> sp;
public:
RMQ() : N(0){};
explicit RMQ(const vector<S> &V) : N(int(V.size())), sp(bit_width((unsigned)max(N - 1, 1)), vector<S>(N))
{
for (int i = 0; i < (int)sp.size(); ++i)
{
for (int j = 0; j < N; j++)
if (j & (1 << i))
{
if ((j - 1) & (1 << i))
sp[i][j] = op(sp[i][j - 1], V[j]);
else
sp[i][j] = V[j];
}
for (int j = N - 1; j >= 0; --j)
if (!(j & (1 << i)))
{
if (j != N - 1 && !((j + 1) & (1 << i)))
sp[i][j] = op(V[j], sp[i][j + 1]);
else
sp[i][j] = V[j];
}
}
}
S Query(int L, int R)
{
assert(0 <= L && L < R && R <= N);
--R;
if (L == R)
return sp[0][L];
int idx = bit_width((unsigned)(L ^ R)) - 1;
return op(sp[idx][L], sp[idx][R]);
}
};
예시 문제
-
예시 문제: JOI Open Contest 2014 6번
-
예시 코드 (위의
RMQ
구현체를 포함하고, 이와 같이 작성합니다)
#include "secret.h"
RMQ<int, Secret> rmq;
void Init(int N, int A[]) { rmq = RMQ<int, Secret>(vector<int>(A, A + N)); }
int Query(int L, int R) { return rmq.Query(L, R + 1); }