ho94949's profile image

ho94949

August 19, 2023 15:00

연산에 대한 O(n log n) / O(1) 구간 쿼리

range-query , data-structure

프로그래밍을 하면서 많이 맞닥뜨리는 연산중 하나는 “구간 쿼리”입니다. 결합법칙을 만족하는 어떤 종류의 연산이든, $O(N \log N)$ 전처리로 쿼리당 $O(1)$ 시간에 구간에 대한 쿼리를 답할 수 있는 구조를 설명하려고 합니다.

구간 쿼리

구간 쿼리는 다음과 같은 문제입니다.

  1. (전처리) 배열 $A = (A_0, A_1, \cdots, A_{N-1})$과 연산 $\circ$가 주어집니다.
  2. (쿼리) $1 \le l \le r \le N$이 주어지면, $A_l \circ A_{l+1} \circ \cdots \circ A_r$을 계산해야합니다.

여기서, 우리가 주목해볼 연산의 성질은 다음과 같습니다.

  • 결합법칙: 임의의 세 원소 $a, b, c$에 대해 $a \circ (b \circ c) = (a \circ b) \circ c$가 성립합니다.
  • 교환법칙: 임의의 두 원소 $a, b$에 대해 $a \circ b = b \circ a$가 성립합니다.
  • 멱등법칙: 임의의 원소 $a$에 대해 $a \circ a = a$가 성립합니다.
  • 항등원의 존재: 임의의 원소 $a$에 대해 $a \circ e = e \circ a = a$ 를 만족하는 고정된 $e$가 존재합니다.
  • 역원의 존재: 임의의 원소 $a$에 대해 $a \circ a^{-1} = a^{-1} \circ a = e$를 만족하는 $a^{-1}$가 존재합니다.

이 법칙들, 항등원 혹은 역원의 존재 여부는 연산마다 다릅니다.

  • $+$는 결합법칙, 교환법칙이 성립하고 항등원, 역원이 존재합니다.
  • $\times$는 결합법칙, 교환법칙이 성립하고 항등원이 존재합니다.
  • $-$는 어떤 법칙도 성립하지 않고, 항등원도 존재하지 않습니다.
  • 두 행렬의 곱셈은 결합법칙이 성립하고 항등원이 존재합니다.
  • 최솟값, 최댓값 연산은 결합법칙, 교환법칙, 멱등법칙이 성립하며, 경우에 따라 항등원이 존재합니다. (주로 $+\infty$, $-\infty$로 표기합니다)
  • $(L_1, R_1) \circ (L_2, R_2) = (L_1, R_2)$로 정의되는 연산이 있다고 합시다. 이 연산은 결합법칙, 멱등법칙이성립합니다.

이렇게 연산의 종류에 따라 다양한 성질이 존재하고, 이 성질에 맞는 특징을 이용해서 문제를 해결하면 됩니다.

  • 결합법칙이 성립하고 역원이 존재하는 연산에 대해서는 $O(N)$ 전처리로 쿼리당 $O(1)$ 시간에 문제의 답을 할 수 있습니다.
  • 결합법칙, 교환법칙, 멱등법칙이 성립하는 연산에 대해서는 $O(N \log N)$ 전처리로 쿼리당 $O(1)$ 시간에 문제의 답을 할 수 있습니다.

첫 번째 방식은 “누적합”, 두 번째 방식은 “희소 배열”이라는 이름으로 잘 알려져있습니다. 우리는 두 번째 방식에서 조건을 약화시킨, “결합법칙”이 성립하는 임의의 연산에 대해서 $O(N \log N)$ 전처리로 쿼리당 $O(1)$ 시간에 문제의 답을 하는 방법에 대해 알아볼 것입니다.

아이디어

기본적인 아이디어는, $A_l \circ A_{l+1} \circ \cdots \circ A_m$과 $A_{m+1} \circ A_{m+2} \circ \cdots \circ A_r$을 알고 있으면, 둘을 연산하는 것으로 $A_l \circ A_{l+1} \circ \cdots \circ A_r$을 계산할 수 있다는 것입니다. 전처리 과정에서 적당한 구간을 연산한 것들을 가지고 있으면, 문제를 해결할 수 있습니다. 편하게, $A_i \circ A_{i+1} \circ \cdots \circ A_j$를 $[i, j+1)$로 표현합시다.

우리가 쿼리당 정확히 한 번의 연산을 통해서 답을 도출해야한다고 하면 문제는 다음과 같이 바뀔 수 있습니다:

  • 모든 $0 \le l < r \le N$ 에 대해서 적당한 $m$이 존재해서 $[l, m)$과 $[m, r)$이 모두 존재하도록 하면 됩니다.

가장 자연스러운 생각은 $m = \frac{0+N}{2}$로 잡는 것입니다. 이 이후, $[0, m); [1, m); \cdots; [m-1, m)$ 을 계산하고 $[m, m+1); [m, m+2); \cdots; [m, r)$을, 총 $N$개의 값을 계산합니다. 이는 반복문으로 쉽게 계산할 수 있습니다. 이제 $0 \le l \le m \le r \le N$을 만족하는 $[l, r)$에 대한 답을 내놓을 수 있습니다.

이제 계산하지 못하는 값은, $0 \le l < r \le m-1$과 $m+1 \le l < r \le N$ 부분이 됩니다. 문제의 형태가 똑같기 때문에 이는 재귀적으로 전처리해줄 수 있습니다. 어떤 $[l, r)$이 들어오면 $l \le m \le r$을 만족하는 올바른 $m$을 찾을 수 있습니다.

전처리에 드는 연산 횟수 및 시간복잡도는 $T(N) = 2T(N/2) + O(N)$ 으로 $T(N) = O(N \log N)$이 되며, 전처리는 한 번의 연산으로 찾을 수 있습니다.

구현

구현을 좀 더 단순하게 해봅시다. 위와같이 구현할 경우, 전처리 과정에서 해당하는 $m$을 찾는 과정이 $O(\log N)$시간이 걸리므로 조금 더 세련된 방법을 생각해봅시다. 가장 쉬운 방법은 $N$을 임의의 길이로 늘려서 $2^k$꼴로 만드는 것입니다. 이 경우, 아래쪽에 재귀호출을 부르는 모든 배열의 길이도 $2$의 거듭제곱꼴이 됩니다. 이제 $[L, R)$구간이 나눠지는 형태를 살펴봅시다.

  • 처음에 $L = 0, R = 2^k$ 이고 구간의 길이 $R-L = 2^k$입니다. 여기서 주목할 점은, $L$과 $R$이 구간의 길이의 배수라는 점입니다.
  • $L, R$이 $R-L = 2^i$의 배수라고 합시다. $m = \frac{L+R}{2}$가 되고, $L, m, R$모두 $m-L = R-m = 2^{i-1}$의 배수가 됩니다.

즉, 구간이 나눠지면서 양쪽 끝 인덱스는 구간 길이에 해당하는 $2^i$의 배수 형태를 유지하게 됩니다. $[L, R)$ 구간은 $j = 0, \cdots, k; v = 0, \cdots, 2^{k-j}-1$에 대해서 $[v \times 2^j, (v+1) \times 2^j)$꼴이 됩니다.

$N=16$일 때 저장된 값

이제 쿼리로 주어진 $[l, r)$에 대해 적당한 $m$을 찾아봅시다. 쿼리로 주어진 $[l, r)$에 대해서 구간의 양 끝점인 $l$과 $r-1$이 모두 속하는 가장 긴 구간 $[L, R)$을 찾으면, $L \le l < \frac{L+R}{2} < r \le R$이 됩니다. 이는 $L, \frac{L+R}{2}, R$이 연속된 $2^{j-1}$꼴의 배수라는 점에서 착안하여, $l$과 $r$이 달라지는 가장 큰 자리수 비트를 이용해서 구하면 됩니다. 이는 $l \oplus r$의 가장 큰 비트를 구하면 됩니다.

구현체

아래는 구현체입니다.

  • sp[i][j]는 크기가 $2^i$인 구간에서 $j$번째 위치에 저장된 값을 의미합니다. 위 그림을 참고하면 좋습니다.
  • C++20의 bit_width에 의존합니다. bit_width(x)32-__builtin_clz(x)로 대체해도 똑같습니다.
  • Query 함수는 $A_l \circ \cdots \circ A_{r-1}$을 계산합니다.
#include <bit>
#include <cassert>
#include <vector>
using namespace std;
 
template <typename S, S (*op)(S, S)>
class RMQ
{
private:
    int N;
    vector<vector<S>> sp;
 
public:
    RMQ() : N(0){};
    explicit RMQ(const vector<S> &V) : N(int(V.size())), sp(bit_width((unsigned)max(N - 1, 1)), vector<S>(N))
    {
        for (int i = 0; i < (int)sp.size(); ++i)
        {
            for (int j = 0; j < N; j++)
                if (j & (1 << i))
                {
                    if ((j - 1) & (1 << i))
                        sp[i][j] = op(sp[i][j - 1], V[j]);
                    else
                        sp[i][j] = V[j];
                }
            for (int j = N - 1; j >= 0; --j)
                if (!(j & (1 << i)))
                {
                    if (j != N - 1 && !((j + 1) & (1 << i)))
                        sp[i][j] = op(V[j], sp[i][j + 1]);
                    else
                        sp[i][j] = V[j];
                }
        }
    }
 
    S Query(int L, int R)
    {
        assert(0 <= L && L < R && R <= N);
        --R;
        if (L == R)
            return sp[0][L];
        int idx = bit_width((unsigned)(L ^ R)) - 1;
        return op(sp[idx][L], sp[idx][R]);
    }
};

예시 문제

#include "secret.h"

RMQ<int, Secret> rmq;
void Init(int N, int A[]) { rmq = RMQ<int, Secret>(vector<int>(A, A + N)); }
int Query(int L, int R) { return rmq.Query(L, R + 1); }