read, write, mmap을 이용한 Fast I/O 구현 (2)
이번 글에서는 이전 글인 read, write, mmap을 이용한 Fast I/O 구현 (1)에 이어 Bit-Twiddling Hack을 적용한 더 빠른 입력 기법과 Lookup Table 및 SIMD를 활용한 출력 기법을 살펴보겠습니다.
1. Faster Input Implementation
read
또는 mmap
을 이용한 빠른 입력 방식은 입력 버퍼를 두고 버퍼 단위로 한 번에 입력을 읽어오며 입력 속도를 향상시킵니다. 대부분의 문제에서는 이 정도 최적화만으로 충분하지만, 한 번에 여러 문자를 처리할 수 있다면 입력 속도를 더욱 높일 수 있습니다.
1.1 8-Byte Parsing
앞서 소개한 Fast I/O는 입력 버퍼를 통해 $1$바이트씩 문자를 읽는 방식이었습니다. Bit-Twiddling 기법을 활용하면, $8$바이트 단위로 입력을 가져와 $64$비트 정수 자료형의 비트 연산으로 한 번에 여러 자리 숫자를 처리할 수 있습니다.
auto read_int = [&] {
int ret = 0, flag = 0;
if (*pr == '-') flag = 1, pr++;
u64 x; memcpy(&x, pr, 8);
x ^= 0x3030303030303030;
if (!(x & 0xf0f0f0f0f0f0f0f0)) {
x = ((x * 10) + (x >> 8)) & 0x00ff00ff00ff00ff;
x = ((x * 100) + (x >> 16)) & 0x0000ffff0000ffff;
x = ((x * 10000) + (x >> 32)) & 0x00000000ffffffff;
ret = x;
pr += 8;
}
while (*pr & 16) ret = 10 * ret + (*pr++ & 15);
pr++;
if (flag) ret = -ret;
return ret;
};
코드에서 memcpy(&x, pr, 8)
은 pr[0]
, $\cdots$, pr[7]
을 $64$비트 정수 자료형 x
에 저장합니다.
ASCII 코드의 성질을 이용하면 x
를 구성하는 $8$개의 문자가 모두 숫자인지 여부를 알아낼 수 있습니다.
숫자 '0'
, $\cdots$, '9'
는 ASCII 코드에서 상위 $4$비트가 항상 0011____
로 고정되어 있습니다. 따라서 입력 문자를 0x30 = 0b00110000
과 xor 연산하면 각 문자의 상위 $4$비트가 0000____
이 됩니다. 즉, '0'
는 0x00
, '1'
는 0x01
, $\cdots$, '9'
는 0x09
로 변환됩니다.
반면, 공백(0x20
), 줄바꿈(0x0A
), 하이픈(0x2D
) 등의 숫자가 아닌 문자는 xor 결과에서 상위 $4$개 비트 중 1
인 비트가 남아있습니다. 따라서 xor 결과를 0xf0 = 0b11110000
와 bitwise and 연산을 한 뒤 결과가 $0$인지 확인하면 문자가 숫자인지 확인할 수 있습니다. 코드에서는 x & 0xf0f0f0f0f0f0f0f0
의 결과가 $0$인지 확인하며 $8$개 문자 중 공백 또는 줄바꿈이 존재하는지 확인합니다.
이제 $8$개 문자가 모두 숫자일 때, 전체 숫자를 이어붙인 값을 빠르게 구하는 과정을 알아보겠습니다.
x = ((x * 10) + (x >> 8)) & 0x00ff00ff00ff00ff;
x = ((x * 100) + (x >> 16)) & 0x0000ffff0000ffff;
x = ((x * 10000) + (x >> 32)) & 0x00000000ffffffff;
초기 상태에 x
에는 $8$개의 숫자가 차례로 저장되어 있습니다. 이를 $2$개씩 묶어서 이어붙이는 과정을 $3$번 반복하면 전체 숫자를 이어붙인 값을 구할 수 있습니다.
첫 번째 줄에서는 $1$바이트씩 숫자를 이어붙입니다. 각 $1$바이트에 저장된 값을 순서대로 $d_0, d_1, \cdots, d_7$이라 합시다. 인접한 두 값 $(d_{2i}, d_{2i+1})$은 (x * 10) + (x >> 8)
에서 $(10d_{2i} + d_{2i+1}, *)$이 되고, 0x00ff
와 bitwise and 연산을 통해 $(10d_{2i} + d_{2i+1}, 0)$이 됩니다. 따라서 첫 번째 줄이 실행되면 $x$의 각 $2$바이트에 두 자리 숫자가 저장됩니다.
비슷하게 두 번째 줄에서는 $2$바이트씩 숫자를 이어붙입니다. 인접한 두 값 $(d_{2i}, d_{2i+1})$은 (x * 100) + (x >> 16)
에서 $(100d_{2i} + d_{2i+1}, 0)$이 되고, 0x0000ffff
와 bitwise and 연산을 통해 $(100d_{2i} + d_{2i + 1}, 0)$이 됩니다.
마지막으로 세 번째 줄에서는 $4$바이트씩 숫자를 이어붙입니다. (x * 10000) + (x >> 32)
와 0x00000000ffffffff
를 이용하면 $8$바이트에 여덟 자리 숫자가 저장됩니다. 이 값을 정답에 기록한 뒤 pr
를 $8$만큼 이동시키면 한 번에 $8$글자를 처리할 수 있습니다.
이후 아직 처리되지 않은 글자를 한 글자씩 순차적으로 읽으며 처리하면 전체 입력을 처리할 수 있습니다.
while (*pr & 16) ret = 10 * ret + (*pr++ & 15);
사용 예시는 다음과 같습니다. (코드)
auto read_int = [&] {
int ret = 0, flag = 0;
if (*pr == '-') flag = 1, pr++;
u32 x; memcpy(&x, pr, 4);
x ^= 0x30303030;
if (!(x & 0xf0f0f0f0)) {
x = ((x * 10) + (x >> 8)) & 0x00ff00ff;
x = ((x * 100) + (x >> 16)) & 0x0000ffff;
ret = x;
pr += 4;
}
while (*pr & 16) ret = 10 * ret + (*pr++ & 15);
pr++;
if (flag) ret = -ret;
return ret;
};
비슷하게, $8$바이트 단위 대신 $4$바이트 단위로 숫자를 처리할 수도 있습니다. 이는 입력되는 숫자가 $8$자리 미만인 경우에 유용합니다.
사용 예시는 다음과 같습니다. (코드)
1.2 8-4-2-1 Byte Parsing
auto read_int = [&] {
int ret = 0, flag = 0;
if (*pr == '-') flag = 1, pr++;
{
u64 x; memcpy(&x, pr, 8);
x ^= 0x3030303030303030;
if (!(x & 0xf0f0f0f0f0f0f0f0)) {
x = ((x * 10) + (x >> 8)) & 0x00ff00ff00ff00ff;
x = ((x * 100) + (x >> 16)) & 0x0000ffff0000ffff;
x = ((x * 10000) + (x >> 32)) & 0x00000000ffffffff;
ret = x;
pr += 8;
}
}
{
u32 x; memcpy(&x, pr, 4);
x ^= 0x30303030;
if (!(x & 0xf0f0f0f0)) {
x = ((x * 10) + (x >> 8)) & 0x00ff00ff;
x = ((x * 100) + (x >> 16)) & 0x0000ffff;
ret = 10000 * ret + x;
pr += 4;
}
}
{
u16 x; memcpy(&x, pr, 2);
x ^= 0x3030;
if (!(x & 0xf0f0)) {
x = ((x * 10) + (x >> 8)) & 0x00ff;
ret = 100 * ret + x;
pr += 2;
}
}
if (*pr & 16) ret = 10 * ret + (*pr++ & 15);
pr++;
if (flag) ret = -ret;
return ret;
};
$8$바이트 단위로 한 번에 여덟 자리 숫자를 처리하는 방법을 응용하면 $8, 4, 2, 1$바이트 단위로 한 번에 여덟 자리, 네 자리, 두 자리, 한 자리 숫자를 처리하며 $4$번의 과정으로 $15$자리 이하의 정수 자료형을 입력받을 수 있습니다.
사용 예시는 다음과 같습니다. (코드)
1.3 8-Byte Parsing with std::countr_zero
auto read_int = [&] {
u64 x; memcpy(&x, pr, 8);
x ^= 0x3030303030303030;
int t = std::countr_zero(x & 0xf0f0f0f0f0f0f0f0) >> 3;
x <<= 64 - (t << 3);
x = (x * 10 + (x >> 8)) & 0x00ff00ff00ff00ff;
x = (x * 100 + (x >> 16)) & 0x0000ffff0000ffff;
x = (x * 10000 + (x >> 32)) & 0x00000000ffffffff;
pr += t + 1;
return x;
};
입력되는 정수가 $8$자리 이하라면 std::countr_zero
를 이용해 정수를 조건 분기 없이 파싱할 수 있습니다.
pr[0]
, $\cdots$, pr[7]
, pr[8]
중 숫자가 아닌 문자가 처음으로 등장하는 위치는 x & 0xf0f0f0f0f0f0f0f0
에 std::countr_zero
를 적용해 알아낼 수 있습니다. std::countr_zero
는 bit
헤더에 정의된 함수로 가장 작은 비트부터 연속한 0
의 개수를 반환합니다. 이를 $8$로 나눈 값을 t
라 하면 pr[t]
는 처음으로 숫자가 아닌 문자가 등장하는 위치가 됩니다.
x <<= 64 - (t << 3)
에서는 x
에서 pr[t]
이후 부분을 없애고 가장 작은 자릿수의 숫자가 최상위 바이트에 위치하도록 합니다. 이후 $8$바이트 단위로 입력을 가져와 비트 연산으로 숫자를 이어붙이는 방식을 사용하면 $8$자리 이하 정수를 조건 분기 없이 파싱할 수 있습니다.
사용 예시는 다음과 같습니다. (코드)
auto read_int = [&] {
u32 x; memcpy(&x, pr, 4);
x ^= 0x30303030;
int t = std::countr_zero(x & 0xf0f0f0f0) >> 3;
x <<= 32 - (t << 3);
x = (x * 10 + (x >> 8)) & 0x00ff00ff;
x = (x * 100 + (x >> 16)) & 0x0000ffff;
pr += t + 1;
return x;
};
비슷하게, 입력되는 정수가 $4$자리 이하라면 $32$비트 정수 자료형을 이용해 정수를 조건 분기 없이 파싱할 수 있습니다.
사용 예시는 다음과 같습니다. (코드)
2. Faster Output Implementation
이전 글에서는 write
함수를 활용한 빠른 출력 구현 방법에 대해 살펴보았습니다. 빠른 출력은 x % 10
연산을 통해 가장 작은 자릿수부터 숫자를 버퍼에 기록하는 방식으로 작동하므로, 출력 순서가 반대입니다. 이를 해결하기 위해 빠른 출력의 구현에서는 숫자를 임시 버퍼에 순차적으로 기록한 후 이를 역순으로 출력 버퍼에 복사하거나, 미리 x
의 자릿수 sz
를 계산한 후 역순으로 직접 출력 버퍼에 기록하는 방식을 사용합니다.
구현에서 주요 오버헤드는 두 가지입니다. 하나는 x % 10
및 x / 10
에서 발생하는 나눗셈 연산이며, 다른 하나는 숫자를 한 자리씩 기록하는 것입니다. 따라서 나눗셈 연산의 빈도를 줄이고, 여러 자릿수를 한 번에 출력 버퍼에 기록할 수 있다면 출력 속도를 더욱 향상시킬 수 있습니다.
2.1 Digit Counting with Bit-Twiddling Hack
constexpr int p10[] = {
0, 10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000, 1000000000
};
auto count_digit = [](u32 n) {
int t = std::bit_width(n) * 1233 >> 12;
return t - (n < p10[t]) + 1;
};
auto write_int = [&](int x) {
if (pw - w + 40 > wbuf_sz) write(1, w, pw - w), pw = w;
if (x < 0) *pw++ = '-', x = -x;
int sz = count_digit(x);
for (int i = sz - 1; i >= 0; i--) pw[i] = x % 10 | 48, x /= 10;
pw += sz;
};
x
의 자릿수를 구하는 고속 출력 기법에서 x
의 자릿수 sz
는 $\lfloor\log_{10} x\rfloor$번의 나눗셈 연산 또는 $3$번의 조건 분기 대신 $1$번의 std::bit_width
호출을 이용해 구할 수 있습니다.
std::bit_width
는 bit
헤더에 정의된 함수로 $\lceil\log_2(n+1)\rceil$을 반환합니다. 여기에 $\log_{10} 2 = 0.301029995664\dots \approx \frac{1233}{4096}$를 곱하면 std::bit_width(n) * 1233 >> 12
로 $\lfloor\log_{10} n\rfloor$을 근사할 수 있습니다. (참고)
오차 범위는 다음과 같이 계산됩니다.
\[\begin{align*} x & =\lceil\log_2(n+1)\rceil \\ c & =\frac{1233}{4096}=0.301025390625 \\ k & =\log_{10}2\approx 0.301029995664 \\ e & =\lfloor xc\rfloor - \lfloor\log_{10}n\rfloor \end{align*}\]$1 \leq x \leq 680$이라면 $\lfloor xc\rfloor = \lfloor xk\rfloor$이 성립합니다.
$n \geq 1$이라면 $\lceil\log_2(n+1)\rceil = \lfloor\log_2n\rfloor + 1$에서 다음이 성립합니다.
\[\begin{align*} & & 2^{x-1} \leq n \leq 2^x \\ \Rightarrow & & (x-1)k \leq \log_{10}n \leq xk \\ \Rightarrow & & \lfloor(x-1)k\rfloor \leq \lfloor\log_{10}n\rfloor \leq \lfloor xk\rfloor \\ \Rightarrow & & 0 \leq \lfloor xk\rfloor - \lfloor\log_{10}n\rfloor \leq 1 \end{align*}\]따라서 $1 \leq n < 2^{680}$ 범위에서 $0 \leq e \leq 1$이 성립하고, 모든 $128$ 비트 이하 unsigned integer에 대해 std::bit_width(n) * 1233 >> 12
는 $\lfloor\log_{10} n\rfloor$ 또는 $\lfloor\log_{10} n\rfloor + 1$입니다. 이때 로그 값이 $1$ 크게 구해진 경우는 $10^0, 10^1, \cdots$를 미리 구해두면 $1$번의 비교 연산으로 보정할 수 있습니다.
이상의 과정을 코드로 옮기면 std::bit_width
을 이용해 count_digit
함수를 구현할 수 있고, write_int
함수는 이전 글과 동일한 방법으로 구현할 수 있습니다.
사용 예시는 다음과 같습니다. (코드)
2.2 2-Digit Lookup Table
constexpr char lut[200] = {
'0','0','0','1','0','2','0','3','0','4','0','5','0','6','0','7','0','8','0','9',
'1','0','1','1','1','2','1','3','1','4','1','5','1','6','1','7','1','8','1','9',
'2','0','2','1','2','2','2','3','2','4','2','5','2','6','2','7','2','8','2','9',
'3','0','3','1','3','2','3','3','3','4','3','5','3','6','3','7','3','8','3','9',
'4','0','4','1','4','2','4','3','4','4','4','5','4','6','4','7','4','8','4','9',
'5','0','5','1','5','2','5','3','5','4','5','5','5','6','5','7','5','8','5','9',
'6','0','6','1','6','2','6','3','6','4','6','5','6','6','6','7','6','8','6','9',
'7','0','7','1','7','2','7','3','7','4','7','5','7','6','7','7','7','8','7','9',
'8','0','8','1','8','2','8','3','8','4','8','5','8','6','8','7','8','8','8','9',
'9','0','9','1','9','2','9','3','9','4','9','5','9','6','9','7','9','8','9','9'
};
auto write_int = [&](int x) {
if (pw - w + 40 > wbuf_sz) write(1, w, pw - w), pw = w;
if (x < 0) *pw++ = '-', x = -x;
char t[10], *pt = t;
while (x >= 100) {
int i = x % 100 << 1;
x /= 100;
*pt++ = lut[i + 1];
*pt++ = lut[i];
}
if (x < 10) {
*pt++ = x + '0';
}
else {
int i = x << 1;
*pt++ = lut[i + 1];
*pt++ = lut[i];
}
do *pw++ = *--pt; while (pt != t);
};
$0 \leq i < 100$ 범위의 $i$에 대해 i / 10
과 i % 10
를 lut
배열에 전처리해두면 x
를 $100$으로 나누며 두 자리씩 출력할 수 있습니다. 이는 나눗셈 횟수를 절반으로 줄일 수 있어 x
를 $10$으로 나누는 방식보다 더 빠르게 동작합니다.
이러한 최적화 기법에서 사용되는 전처리 배열을 lookup table이라 합니다.
사용 예시는 다음과 같습니다. (코드)
constexpr auto lut = [] {
std::array<char, 200> res;
for (int i = 0; i < 100; i++) {
res[2 * i] = i / 10 | 48;
res[2 * i + 1] = i % 10 | 48;
}
return res;
}();
constexpr 키워드를 이용하면 컴파일 타임에 lut
를 계산할 수 있어서 전역에 상수 값을 직접 넣어주지 않아도 됩니다.
사용 예시는 다음과 같습니다. (코드)
2.3 4-Digit Lookup Table
constexpr auto lut = [] {
std::array<std::array<char, 4>, 10000> res;
for (int i = 0; i < 10000; i++) {
res[i][0] = i / 1000 | 48;
res[i][1] = i / 100 % 10 | 48;
res[i][2] = i / 10 % 10 | 48;
res[i][3] = i % 10 | 48;
}
return res;
}();
auto write_4digit_trimmed = [&](int x) {
if (x > 999) {
memcpy(pw, &lut[x], 4);
pw += 4;
}
else if (x > 99) {
memcpy(pw, &lut[x * 10], 3);
pw += 3;
}
else if (x > 9) {
memcpy(pw, &lut[x * 100], 2);
pw += 2;
}
else {
memcpy(pw, &lut[x * 1000], 1);
pw += 1;
}
};
auto write_4digit = [&](int x) {
memcpy(pw, &lut[x], 4);
pw += 4;
};
auto write_int = [&](int x) {
if (pw - w + 40 > wbuf_sz) write(1, w, pw - w), pw = w;
if (x < 0) *pw++ = '-', x = -x;
if (x > 99999999) {
write_4digit_trimmed(x / 10000 / 10000);
write_4digit(x / 10000 % 10000);
write_4digit(x % 10000);
}
else if (x > 9999) {
write_4digit_trimmed(x / 10000);
write_4digit(x % 10000);
}
else {
write_4digit_trimmed(x);
}
};
$0 \leq i < 10\,000$ 범위의 $i$에 대해 $i$의 $j$번째 자릿수를 lut[i][4 - j]
에 저장해 lookup table을 구성하면, 숫자를 문자열로 변환할 때 한 번에 $4$자리씩 버퍼를 채울 수 있습니다.
write_4digit_trimmed
와 write_4digit
함수는 $4$자리 이하 또는 정확히 $4$자리 숫자를 memcpy
로 한 번에 버퍼에 복사하는 함수입니다. write_int
함수는 x
의 자릿수에 대해 조건 분기를 하며 해당 함수를 이용해 x
를$4$자리씩 끊어서 출력합니다.
이렇게 $4$자리씩 memcpy
를 이용해 lookup table의 값을 버퍼에 복사하면 한 번에 여러 개의 값을 처리할 수 있어 각 자릿수를 하나씩 처리하는 방식보다 효율적입니다.
사용 예시는 다음과 같습니다. (코드)
constexpr auto lut_digit_trimmed = [] {
std::array<std::array<char, 4>, 10000> res{};
for (int i = 0; i < 10; i++) {
res[i][0] = i | 48;
}
for (int i = 10; i < 100; i++) {
res[i][0] = i / 10 | 48;
res[i][1] = i % 10 | 48;
}
for (int i = 100; i < 1000; i++) {
res[i][0] = i / 100 | 48;
res[i][1] = i / 10 % 10 | 48;
res[i][2] = i % 10 | 48;
}
for (int i = 1000; i < 10000; i++) {
res[i][0] = i / 1000 | 48;
res[i][1] = i / 100 % 10 | 48;
res[i][2] = i / 10 % 10 | 48;
res[i][3] = i % 10 | 48;
}
return res;
}();
constexpr auto lut_digit = [] {
std::array<std::array<char, 4>, 10000> res;
for (int i = 0; i < 10000; i++) {
res[i][0] = i / 1000 | 48;
res[i][1] = i / 100 % 10 | 48;
res[i][2] = i / 10 % 10 | 48;
res[i][3] = i % 10 | 48;
}
return res;
}();
constexpr auto lut_sz = [] {
std::array<char, 10000> res;
for (int i = 0; i < 10; i++) res[i] = 1;
for (int i = 10; i < 100; i++) res[i] = 2;
for (int i = 100; i < 1000; i++) res[i] = 3;
for (int i = 1000; i < 10000; i++) res[i] = 4;
return res;
}();
auto write_4digit_trimmed = [&](int x) {
memcpy(pw, &lut_digit_trimmed[x], 4);
pw += lut_sz[x];
};
auto write_4digit = [&](int x) {
memcpy(pw, &lut_digit[x], 4);
pw += 4;
};
auto write_int = [&](int x) {
if (pw - w + 40 > wbuf_sz) write(1, w, pw - w), pw = w;
if (x < 0) *pw++ = '-', x = -x;
if (x > 99999999) {
write_4digit_trimmed(x / 10000 / 10000);
write_4digit(x / 10000 % 10000);
write_4digit(x % 10000);
}
else if (x > 9999) {
write_4digit_trimmed(x / 10000);
write_4digit(x % 10000);
}
else {
write_4digit_trimmed(x);
}
};
$4$자리 미만 수의 leading zero를 제외한 표현과 자릿수에 대한 lookup table을 추가로 구성하면 write_4digit_trimmed
의 구현에서 조건 분기를 없앨 수 있습니다.
사용 예시는 다음과 같습니다. (코드)
2.4 SIMD(SSE2) Vectorized Output
alignas(16) constexpr u32 kDiv10000Vec[4] = { 3518437209, 3518437209, 3518437209, 3518437209 };
alignas(16) constexpr u32 k10000Vec[4] = { 10000, 10000, 10000, 10000 };
alignas(16) constexpr u16 kDivPwrVec[8] = { 8389, 5243, 13108, 32768, 8389, 5243, 13108, 32768 };
alignas(16) constexpr u16 k10Vec[8] = { 10, 10, 10, 10, 10, 10, 10, 10 };
alignas(16) constexpr u16 kShiftPwrVec[8] = { 128, 2048, 8192, 32768, 128, 2048, 8192, 32768 };
alignas(16) constexpr char kAsciiZero[16] = { 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48 };
inline __m128i convert8digits(u32 x) {
__m128i vv = _mm_cvtsi32_si128(x);
__m128i abcd = _mm_srli_epi64(_mm_mul_epu32(vv, ((__m128i*)(kDiv10000Vec))[0]), 45);
__m128i efgh = _mm_sub_epi32(vv, _mm_mul_epu32(abcd, ((__m128i*)(k10000Vec))[0]));
__m128i v1 = _mm_unpacklo_epi16(abcd, efgh);
__m128i v1x4 = _mm_slli_epi64(v1, 2);
__m128i v2a = _mm_unpacklo_epi16(v1x4, v1x4);
__m128i v2 = _mm_unpacklo_epi32(v2a, v2a);
__m128i v3 = _mm_mulhi_epu16(v2, ((__m128i*)(kDivPwrVec))[0]);
__m128i v4 = _mm_mulhi_epu16(v3, ((__m128i*)(kShiftPwrVec))[0]);
__m128i v5 = _mm_mullo_epi16(v4, ((__m128i*)(k10Vec))[0]);
__m128i v6 = _mm_slli_epi64(v5, 16);
return _mm_sub_epi16(v4, v6);
}
auto write_int = [&](int x) {
if (pw - w + 40 > wbuf_sz) write(1, w, pw - w), pw = w;
if (x < 0) *pw++ = '-', x = -x;
if (x > 99999999) {
write_4digit_trimmed(x / 100000000);
x %= 100000000;
__m128i digits = convert8digits(x);
__m128i ascii = _mm_add_epi8(
_mm_packus_epi16(_mm_setzero_si128(), digits),
((__m128i*)(kAsciiZero))[0]
);
ascii = _mm_srli_si128(ascii, 8);
_mm_storel_epi64((__m128i*)(pw), ascii);
pw += 8;
}
else if (x > 9999) {
write_4digit_trimmed(x / 10000);
write_4digit(x % 10000);
}
else {
write_4digit_trimmed(x);
}
};
SIMD(Single Instruction Multiple Data)를 이용하면 $32$비트 정수를 한 번에 $8$자리씩 변환해 버퍼에 기록할 수 있습니다. (참고)
구현은 4-digit lookup table과 유사합니다. 만약 출력하는 수가 $8$자리 이하라면 lut
와 memcpy
를 이용해 출력을 처리할 수 있고, $9$자리 이상이라면 앞부분을 lut
를 이용해 출력한 뒤 남은 $8$자리를 convert8digits
함수를 이용해 SIMD로 처리합니다.
convert8digits
함수는 기존의 Bit-Twiddling Hack 기반 고속 입력 기법에서 $8$자리 숫자를 처리하던 과정을 역으로 수행하여 구현할 수 있습니다. 이 과정에서는 각 자릿수에 대한 산술 연산이 필요하므로, 단순한 $64$비트 정수 자료형만으로는 충분하지 않으며, 대신 $128$비트 벡터 자료형인 __m128i
와 SSE2 명령어 집합을 활용하여 병렬적으로 연산을 수행합니다. 나눗셈 연산은 성능 향상을 위해 Barrett Reduction 기법을 사용하여 곱셈과 bitwise shift 연산으로 대체됩니다.
이 방식은 $9$자리 이상의 정수를 $8$자리 단위로 분할하여 한 번에 출력할 수 있어, 이론적으로는 성능 향상을 기대할 수 있습니다. 그러나 실제 테스트 결과, SIMD를 이용해 $8$자리를 병렬로 처리하는 방식보다 4-digit lookup table을 이용해 $4$자리씩 처리하는 방식이 더 좋은 성능을 보였습니다. 이는 룩업 테이블 기반 방식이 memcpy
와 같은 단순한 연산만으로 구현되기 때문인 것으로 생각됩니다.
사용 예시는 다음과 같습니다. (코드)
3. Summary
이번 글에서는 저수준 입출력 기법을 한층 더 최적화하기 위해 Bit-Twiddling Hack, Lookup Table, 그리고 SIMD(Vectorized Instruction) 기술을 활용한 입력 및 출력 기법을 다루었습니다. 입력 최적화에서는 $8$바이트 정수를 묶어서 처리하는 방식, 그리고 std::countr_zero
와 같은 비트 연산 기반 기술을 적용하여 조건 분기 없이 정수를 파싱하는 방법을 살펴보았습니다. 이는 한 번에 여러 값을 처리하며 입력 속도를 향상합니다. 출력 최적화에서는 나눗셈 연산의 횟수를 줄이고 여러 자릿수를 한 번에 출력하기 위해 2-digit 및 4-digit lookup table을 이용하였으며, SSE2 명령어 집합을 활용한 $8$바이트 출력 기법을 살펴보았습니다.
지금까지 살펴본 구현 방식의 장단점, 성능을 정리하면 다음과 같습니다.
[1] Fast Input Methods: BOJ 11003
빠른 입력기법 | 장점 | 단점 | 실행 시간 | 코드 |
---|---|---|---|---|
scanf |
C 표준 입력 | 느림 | 1932ms | link |
cin |
C++ 표준 입력 | 느림 | 1412ms | link |
read |
메모리를 적게 사용 | 약간 느림 | 876ms | link |
mmap |
빠름 | 메모리를 많이 사용 | 848ms | link |
8-Byte Parsing | 매우 빠름 | 메모리를 많이 사용 | 812ms | link |
8-4-2-1 Byte Parsing | 매우 빠름 | 메모리를 많이 사용 | 808ms | link |
fastest | - | - | 152ms | link |
결과를 보면 read
, mmap
를 이용한 기본적인 빠른 입력 구현도 실행 시간을 절반 이상 단축할 정도로 충분히 성능이 좋은 걸 알 수 있습니다. 대부분의 경우는 이정도 구현으로 충분한 성능을 얻을 수 있을 것입니다.
가장 좋은 성능을 보인 방법은 Bit-Twiddling Hack을 이용한 8-Byte Parsing입니다. 추가적인 최적화가 필요하다면 해당 방법을 이용하면 좋을 것입니다.
[2] Fast Conditioned Input Methods: BOJ 15552
빠른 입력기법 | 장점 | 단점 | 실행 시간 | 코드 |
---|---|---|---|---|
scanf |
C 표준 입력 | 느림 | 288ms | link |
cin |
C++ 표준 입력 | 느림 | 220ms | link |
read |
메모리를 적게 사용 | 약간 느림 | 108ms | link |
read ($x \geq 0$, single space) |
메모리를 적게 사용 | 약간 느림 | 104ms | link |
mmap |
빠름 | 메모리를 많이 사용 | 104ms | link |
mmap ($x \geq 0$, single space) |
빠름 | 메모리를 많이 사용 | 100ms | link |
mmap with bitwise operation ($x \geq 0$, single space) |
빠름 | 메모리를 많이 사용 | 96ms | link |
8-Byte Parsing with std::countr_zero ($x \geq 0$, single space) |
매우 빠름 | 메모리를 많이 사용, $x$가 $8$자리 이하여야 함 | 92ms | link |
4-Byte Parsing with std::countr_zero ($x \geq 0$, single space) |
매우 빠름 | 메모리를 많이 사용, $x$가 $4$자리 이하여야 함 | 92ms | link |
fastest | - | - | 12ms | link |
입력 파일이 특정 조건을 만족한다고 가정하면 빠른 입력 기법을 더욱 최적화할 수 있습니다.
테스트에 사용한 문제는 입력 범위가 $[1, 10^3]$이기 때문에 std::countr_zero
를 이용한 4-Byte Parsing이 가장 좋은 성능을 보였습니다.
[3] Fast Output Methods: BOJ 11003
빠른 출력기법 | 장점 | 단점 | 실행 시간 | 코드 |
---|---|---|---|---|
printf |
C 표준 출력 | 느림 | 1932ms | link |
cout |
C++ 표준 출력 | 느림 | 1584ms | link |
write |
빠름 | - | 1248ms | link |
Digit Counting (naive) | 빠름 | 임시 버퍼보다 약간 느림 | 1252ms | link |
Digit Counting (binary search) | 빠름 | 코드 길이 증가 | 1244ms | link |
Digit Counting (std::bit_width ) |
매우 빠름 | $10^i$ 전처리 필요 | 1212ms | link |
2-Digit Lookup Table | 빠름 | 전처리 필요, 코드 길이 증가 | 1228ms | link |
4-Digit Lookup Table | 매우 빠름 | 전처리 필요, 코드 길이 증가 | 1196ms | link |
4-Digit Lookup Table (optimized) | 매우 빠름 | 전처리 필요, 코드 길이 증가 | 1188ms | link |
SIMD(SSE2) | 매우 빠름 | 전처리 필요, 코드 길이 증가, 4-Digit LUT보다 약간 느림 | 1216ms | link |
fastest | - | - | 152ms | link |
빠른 출력 기법은 write
를 이용한 구현이 구현 난이도 대비 뚜렷한 성능 향상을 보였습니다. 대부분의 경우는 해당 기법을 이용하면 충분히 좋은 성능을 얻을 수 있을 것입니다.
가장 좋은 성능을 보인 기법은 Lookup Table $3$개를 이용하는 4-Digit Lookup Table 코드입니다. 이는 SSE2를 이용한 SIMD 기법보다 구현 난이도가 낮으면서 더 좋은 성능을 보이기에 최대한 성능을 최적화해야 하는 경우에 사용하면 적합할 것입니다.
눈여겨볼 점은, 빠른 입력, 출력 기법을 모두 적용한 코드의 실행 시간은 152ms에 불과해, 아무 최적화도 적용하지 않은 코드보다 약 $12.7$배 빠른 성능을 보였다는 점입니다. 이는 BOJ 11003처럼 입출력의 양이 많은 경우에 동일한 로직이라도 어떤 입출력 방식을 선택하느냐에 따라 실행 시간이 크게 달라질 수 있음을 나타냅니다.
Fast I/O 기법은 알고리즘 자체를 바꾸지 않고도 성능을 크게 개선할 수 있는 수단이므로, 익혀두면 다양한 문제에서 유용하게 활용될 수 있을 것입니다. 이상으로 글을 마치며, Fast I/O를 공부하고자 하는 분들께 이 글이 의미 있는 자료가 되었기를 바랍니다.
References
[1] https://github.com/miloyip/itoa-benchmark/tree/master
[2] https://github.com/fmtlib/fmt
[3] https://graphics.stanford.edu/~seander/bithacks.html#IntegerLog10
[4] https://www.youtube.com/watch?v=fV6qYho-XVs
[5] https://blog.quarkslab.com/unaligned-accesses-in-cc-what-why-and-solutions-to-do-it-properly.html