jinhan814's profile image

jinhan814

May 11, 2025 00:00

read, write, mmap을 이용한 Fast I/O 구현 (2)

algorithm , problem solving

이번 글에서는 이전 글인 read, write, mmap을 이용한 Fast I/O 구현 (1)에 이어 Bit-Twiddling Hack을 적용한 더 빠른 입력 기법과 Lookup Table 및 SIMD를 활용한 출력 기법을 살펴보겠습니다.

1. Faster Input Implementation

read 또는 mmap을 이용한 빠른 입력 방식은 입력 버퍼를 두고 버퍼 단위로 한 번에 입력을 읽어오며 입력 속도를 향상시킵니다. 대부분의 문제에서는 이 정도 최적화만으로 충분하지만, 한 번에 여러 문자를 처리할 수 있다면 입력 속도를 더욱 높일 수 있습니다.

1.1 8-Byte Parsing

앞서 소개한 Fast I/O는 입력 버퍼를 통해 $1$바이트씩 문자를 읽는 방식이었습니다. Bit-Twiddling 기법을 활용하면, $8$바이트 단위로 입력을 가져와 $64$비트 정수 자료형의 비트 연산으로 한 번에 여러 자리 숫자를 처리할 수 있습니다.

auto read_int = [&] {
	int ret = 0, flag = 0;
	if (*pr == '-') flag = 1, pr++;
	u64 x; memcpy(&x, pr, 8);
	x ^= 0x3030303030303030;
	if (!(x & 0xf0f0f0f0f0f0f0f0)) {
		x = ((x * 10) + (x >> 8)) & 0x00ff00ff00ff00ff;
		x = ((x * 100) + (x >> 16)) & 0x0000ffff0000ffff;
		x = ((x * 10000) + (x >> 32)) & 0x00000000ffffffff;
		ret = x;
		pr += 8;
	}
	while (*pr & 16) ret = 10 * ret + (*pr++ & 15);
	pr++;
	if (flag) ret = -ret;
	return ret;
};

코드에서 memcpy(&x, pr, 8)pr[0], $\cdots$, pr[7]을 $64$비트 정수 자료형 x에 저장합니다.

ASCII 코드의 성질을 이용하면 x를 구성하는 $8$개의 문자가 모두 숫자인지 여부를 알아낼 수 있습니다.

숫자 '0', $\cdots$, '9'는 ASCII 코드에서 상위 $4$비트가 항상 0011____로 고정되어 있습니다. 따라서 입력 문자를 0x30 = 0b00110000과 xor 연산하면 각 문자의 상위 $4$비트가 0000____이 됩니다. 즉, '0'0x00, '1'0x01, $\cdots$, '9'0x09로 변환됩니다.

반면, 공백(0x20), 줄바꿈(0x0A), 하이픈(0x2D) 등의 숫자가 아닌 문자는 xor 결과에서 상위 $4$개 비트 중 1인 비트가 남아있습니다. 따라서 xor 결과를 0xf0 = 0b11110000와 bitwise and 연산을 한 뒤 결과가 $0$인지 확인하면 문자가 숫자인지 확인할 수 있습니다. 코드에서는 x & 0xf0f0f0f0f0f0f0f0의 결과가 $0$인지 확인하며 $8$개 문자 중 공백 또는 줄바꿈이 존재하는지 확인합니다.

이제 $8$개 문자가 모두 숫자일 때, 전체 숫자를 이어붙인 값을 빠르게 구하는 과정을 알아보겠습니다.

x = ((x * 10) + (x >> 8)) & 0x00ff00ff00ff00ff;
x = ((x * 100) + (x >> 16)) & 0x0000ffff0000ffff;
x = ((x * 10000) + (x >> 32)) & 0x00000000ffffffff;

초기 상태에 x에는 $8$개의 숫자가 차례로 저장되어 있습니다. 이를 $2$개씩 묶어서 이어붙이는 과정을 $3$번 반복하면 전체 숫자를 이어붙인 값을 구할 수 있습니다.

첫 번째 줄에서는 $1$바이트씩 숫자를 이어붙입니다. 각 $1$바이트에 저장된 값을 순서대로 $d_0, d_1, \cdots, d_7$이라 합시다. 인접한 두 값 $(d_{2i}, d_{2i+1})$은 (x * 10) + (x >> 8)에서 $(10d_{2i} + d_{2i+1}, *)$이 되고, 0x00ff와 bitwise and 연산을 통해 $(10d_{2i} + d_{2i+1}, 0)$이 됩니다. 따라서 첫 번째 줄이 실행되면 $x$의 각 $2$바이트에 두 자리 숫자가 저장됩니다.

비슷하게 두 번째 줄에서는 $2$바이트씩 숫자를 이어붙입니다. 인접한 두 값 $(d_{2i}, d_{2i+1})$은 (x * 100) + (x >> 16)에서 $(100d_{2i} + d_{2i+1}, 0)$이 되고, 0x0000ffff와 bitwise and 연산을 통해 $(100d_{2i} + d_{2i + 1}, 0)$이 됩니다.

마지막으로 세 번째 줄에서는 $4$바이트씩 숫자를 이어붙입니다. (x * 10000) + (x >> 32)0x00000000ffffffff를 이용하면 $8$바이트에 여덟 자리 숫자가 저장됩니다. 이 값을 정답에 기록한 뒤 pr를 $8$만큼 이동시키면 한 번에 $8$글자를 처리할 수 있습니다.

이후 아직 처리되지 않은 글자를 한 글자씩 순차적으로 읽으며 처리하면 전체 입력을 처리할 수 있습니다.

while (*pr & 16) ret = 10 * ret + (*pr++ & 15);

사용 예시는 다음과 같습니다. (코드)

auto read_int = [&] {
	int ret = 0, flag = 0;
	if (*pr == '-') flag = 1, pr++;
	u32 x; memcpy(&x, pr, 4);
	x ^= 0x30303030;
	if (!(x & 0xf0f0f0f0)) {
		x = ((x * 10) + (x >> 8)) & 0x00ff00ff;
		x = ((x * 100) + (x >> 16)) & 0x0000ffff;
		ret = x;
		pr += 4;
	}
	while (*pr & 16) ret = 10 * ret + (*pr++ & 15);
	pr++;
	if (flag) ret = -ret;
	return ret;
};

비슷하게, $8$바이트 단위 대신 $4$바이트 단위로 숫자를 처리할 수도 있습니다. 이는 입력되는 숫자가 $8$자리 미만인 경우에 유용합니다.

사용 예시는 다음과 같습니다. (코드)

1.2 8-4-2-1 Byte Parsing

auto read_int = [&] {
	int ret = 0, flag = 0;
	if (*pr == '-') flag = 1, pr++;
	{
		u64 x; memcpy(&x, pr, 8);
		x ^= 0x3030303030303030;
		if (!(x & 0xf0f0f0f0f0f0f0f0)) {
			x = ((x * 10) + (x >> 8)) & 0x00ff00ff00ff00ff;
			x = ((x * 100) + (x >> 16)) & 0x0000ffff0000ffff;
			x = ((x * 10000) + (x >> 32)) & 0x00000000ffffffff;
			ret = x;
			pr += 8;
		}
	}
	{
		u32 x; memcpy(&x, pr, 4);
		x ^= 0x30303030;
		if (!(x & 0xf0f0f0f0)) {
			x = ((x * 10) + (x >> 8)) & 0x00ff00ff;
			x = ((x * 100) + (x >> 16)) & 0x0000ffff;
			ret = 10000 * ret + x;
			pr += 4;
		}
	}
	{
		u16 x; memcpy(&x, pr, 2);
		x ^= 0x3030;
		if (!(x & 0xf0f0)) {
			x = ((x * 10) + (x >> 8)) & 0x00ff;
			ret = 100 * ret + x;
			pr += 2;
		}
	}
	if (*pr & 16) ret = 10 * ret + (*pr++ & 15);
	pr++;
	if (flag) ret = -ret;
	return ret;
};

$8$바이트 단위로 한 번에 여덟 자리 숫자를 처리하는 방법을 응용하면 $8, 4, 2, 1$바이트 단위로 한 번에 여덟 자리, 네 자리, 두 자리, 한 자리 숫자를 처리하며 $4$번의 과정으로 $15$자리 이하의 정수 자료형을 입력받을 수 있습니다.

사용 예시는 다음과 같습니다. (코드)

1.3 8-Byte Parsing with std::countr_zero

auto read_int = [&] {
	u64 x; memcpy(&x, pr, 8);
	x ^= 0x3030303030303030;
	int t = std::countr_zero(x & 0xf0f0f0f0f0f0f0f0) >> 3;
	x <<= 64 - (t << 3);
	x = (x * 10 + (x >> 8)) & 0x00ff00ff00ff00ff;
	x = (x * 100 + (x >> 16)) & 0x0000ffff0000ffff;
	x = (x * 10000 + (x >> 32)) & 0x00000000ffffffff;
	pr += t + 1;
	return x;
};

입력되는 정수가 $8$자리 이하라면 std::countr_zero를 이용해 정수를 조건 분기 없이 파싱할 수 있습니다.

pr[0], $\cdots$, pr[7], pr[8] 중 숫자가 아닌 문자가 처음으로 등장하는 위치는 x & 0xf0f0f0f0f0f0f0f0std::countr_zero를 적용해 알아낼 수 있습니다. std::countr_zerobit 헤더에 정의된 함수로 가장 작은 비트부터 연속한 0의 개수를 반환합니다. 이를 $8$로 나눈 값을 t라 하면 pr[t]는 처음으로 숫자가 아닌 문자가 등장하는 위치가 됩니다.

x <<= 64 - (t << 3)에서는 x에서 pr[t] 이후 부분을 없애고 가장 작은 자릿수의 숫자가 최상위 바이트에 위치하도록 합니다. 이후 $8$바이트 단위로 입력을 가져와 비트 연산으로 숫자를 이어붙이는 방식을 사용하면 $8$자리 이하 정수를 조건 분기 없이 파싱할 수 있습니다.

사용 예시는 다음과 같습니다. (코드)

auto read_int = [&] {
	u32 x; memcpy(&x, pr, 4);
	x ^= 0x30303030;
	int t = std::countr_zero(x & 0xf0f0f0f0) >> 3;
	x <<= 32 - (t << 3);
	x = (x * 10 + (x >> 8)) & 0x00ff00ff;
	x = (x * 100 + (x >> 16)) & 0x0000ffff;
	pr += t + 1;
	return x;
};

비슷하게, 입력되는 정수가 $4$자리 이하라면 $32$비트 정수 자료형을 이용해 정수를 조건 분기 없이 파싱할 수 있습니다.

사용 예시는 다음과 같습니다. (코드)

2. Faster Output Implementation

이전 글에서는 write 함수를 활용한 빠른 출력 구현 방법에 대해 살펴보았습니다. 빠른 출력은 x % 10 연산을 통해 가장 작은 자릿수부터 숫자를 버퍼에 기록하는 방식으로 작동하므로, 출력 순서가 반대입니다. 이를 해결하기 위해 빠른 출력의 구현에서는 숫자를 임시 버퍼에 순차적으로 기록한 후 이를 역순으로 출력 버퍼에 복사하거나, 미리 x의 자릿수 sz를 계산한 후 역순으로 직접 출력 버퍼에 기록하는 방식을 사용합니다.

구현에서 주요 오버헤드는 두 가지입니다. 하나는 x % 10x / 10에서 발생하는 나눗셈 연산이며, 다른 하나는 숫자를 한 자리씩 기록하는 것입니다. 따라서 나눗셈 연산의 빈도를 줄이고, 여러 자릿수를 한 번에 출력 버퍼에 기록할 수 있다면 출력 속도를 더욱 향상시킬 수 있습니다.

2.1 Digit Counting with Bit-Twiddling Hack

constexpr int p10[] = {
	0, 10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000, 1000000000
};

auto count_digit = [](u32 n) {
	int t = std::bit_width(n) * 1233 >> 12;
	return t - (n < p10[t]) + 1;
};

auto write_int = [&](int x) {
	if (pw - w + 40 > wbuf_sz) write(1, w, pw - w), pw = w;
	if (x < 0) *pw++ = '-', x = -x;
	int sz = count_digit(x);
	for (int i = sz - 1; i >= 0; i--) pw[i] = x % 10 | 48, x /= 10;
	pw += sz;
};

x의 자릿수를 구하는 고속 출력 기법에서 x의 자릿수 sz는 $\lfloor\log_{10} x\rfloor$번의 나눗셈 연산 또는 $3$번의 조건 분기 대신 $1$번의 std::bit_width 호출을 이용해 구할 수 있습니다.

std::bit_widthbit 헤더에 정의된 함수로 $\lceil\log_2(n+1)\rceil$을 반환합니다. 여기에 $\log_{10} 2 = 0.301029995664\dots \approx \frac{1233}{4096}$를 곱하면 std::bit_width(n) * 1233 >> 12로 $\lfloor\log_{10} n\rfloor$을 근사할 수 있습니다. (참고)

오차 범위는 다음과 같이 계산됩니다.

\[\begin{align*} x & =\lceil\log_2(n+1)\rceil \\ c & =\frac{1233}{4096}=0.301025390625 \\ k & =\log_{10}2\approx 0.301029995664 \\ e & =\lfloor xc\rfloor - \lfloor\log_{10}n\rfloor \end{align*}\]

$1 \leq x \leq 680$이라면 $\lfloor xc\rfloor = \lfloor xk\rfloor$이 성립합니다.

$n \geq 1$이라면 $\lceil\log_2(n+1)\rceil = \lfloor\log_2n\rfloor + 1$에서 다음이 성립합니다.

\[\begin{align*} & & 2^{x-1} \leq n \leq 2^x \\ \Rightarrow & & (x-1)k \leq \log_{10}n \leq xk \\ \Rightarrow & & \lfloor(x-1)k\rfloor \leq \lfloor\log_{10}n\rfloor \leq \lfloor xk\rfloor \\ \Rightarrow & & 0 \leq \lfloor xk\rfloor - \lfloor\log_{10}n\rfloor \leq 1 \end{align*}\]

따라서 $1 \leq n < 2^{680}$ 범위에서 $0 \leq e \leq 1$이 성립하고, 모든 $128$ 비트 이하 unsigned integer에 대해 std::bit_width(n) * 1233 >> 12는 $\lfloor\log_{10} n\rfloor$ 또는 $\lfloor\log_{10} n\rfloor + 1$입니다. 이때 로그 값이 $1$ 크게 구해진 경우는 $10^0, 10^1, \cdots$를 미리 구해두면 $1$번의 비교 연산으로 보정할 수 있습니다.

이상의 과정을 코드로 옮기면 std::bit_width을 이용해 count_digit 함수를 구현할 수 있고, write_int 함수는 이전 글과 동일한 방법으로 구현할 수 있습니다.

사용 예시는 다음과 같습니다. (코드)

2.2 2-Digit Lookup Table

constexpr char lut[200] = {
	'0','0','0','1','0','2','0','3','0','4','0','5','0','6','0','7','0','8','0','9',
	'1','0','1','1','1','2','1','3','1','4','1','5','1','6','1','7','1','8','1','9',
	'2','0','2','1','2','2','2','3','2','4','2','5','2','6','2','7','2','8','2','9',
	'3','0','3','1','3','2','3','3','3','4','3','5','3','6','3','7','3','8','3','9',
	'4','0','4','1','4','2','4','3','4','4','4','5','4','6','4','7','4','8','4','9',
	'5','0','5','1','5','2','5','3','5','4','5','5','5','6','5','7','5','8','5','9',
	'6','0','6','1','6','2','6','3','6','4','6','5','6','6','6','7','6','8','6','9',
	'7','0','7','1','7','2','7','3','7','4','7','5','7','6','7','7','7','8','7','9',
	'8','0','8','1','8','2','8','3','8','4','8','5','8','6','8','7','8','8','8','9',
	'9','0','9','1','9','2','9','3','9','4','9','5','9','6','9','7','9','8','9','9'
};

auto write_int = [&](int x) {
	if (pw - w + 40 > wbuf_sz) write(1, w, pw - w), pw = w;
	if (x < 0) *pw++ = '-', x = -x;
	char t[10], *pt = t;
	while (x >= 100) {
		int i = x % 100 << 1;
		x /= 100;
		*pt++ = lut[i + 1];
		*pt++ = lut[i];
	}
	if (x < 10) {
		*pt++ = x + '0';
	}
	else {
		int i = x << 1;
		*pt++ = lut[i + 1];
		*pt++ = lut[i];
	}
	do *pw++ = *--pt; while (pt != t);
};

$0 \leq i < 100$ 범위의 $i$에 대해 i / 10i % 10lut 배열에 전처리해두면 x를 $100$으로 나누며 두 자리씩 출력할 수 있습니다. 이는 나눗셈 횟수를 절반으로 줄일 수 있어 x를 $10$으로 나누는 방식보다 더 빠르게 동작합니다.

이러한 최적화 기법에서 사용되는 전처리 배열을 lookup table이라 합니다.

사용 예시는 다음과 같습니다. (코드)

constexpr auto lut = [] {
	std::array<char, 200> res;
	for (int i = 0; i < 100; i++) {
		res[2 * i] = i / 10 | 48;
		res[2 * i + 1] = i % 10 | 48;
	}
	return res;
}();

constexpr 키워드를 이용하면 컴파일 타임에 lut를 계산할 수 있어서 전역에 상수 값을 직접 넣어주지 않아도 됩니다.

사용 예시는 다음과 같습니다. (코드)

2.3 4-Digit Lookup Table

constexpr auto lut = [] {
	std::array<std::array<char, 4>, 10000> res;
	for (int i = 0; i < 10000; i++) {
		res[i][0] = i / 1000 | 48;
		res[i][1] = i / 100 % 10 | 48;
		res[i][2] = i / 10 % 10 | 48;
		res[i][3] = i % 10 | 48;
	}
	return res;
}();

auto write_4digit_trimmed = [&](int x) {
	if (x > 999) {
		memcpy(pw, &lut[x], 4);
		pw += 4;
	}
	else if (x > 99) {
		memcpy(pw, &lut[x * 10], 3);
		pw += 3;
	}
	else if (x > 9) {
		memcpy(pw, &lut[x * 100], 2);
		pw += 2;
	}
	else {
		memcpy(pw, &lut[x * 1000], 1);
		pw += 1;
	}
};

auto write_4digit = [&](int x) {
	memcpy(pw, &lut[x], 4);
	pw += 4;
};

auto write_int = [&](int x) {
	if (pw - w + 40 > wbuf_sz) write(1, w, pw - w), pw = w;
	if (x < 0) *pw++ = '-', x = -x;
	if (x > 99999999) {
		write_4digit_trimmed(x / 10000 / 10000);
		write_4digit(x / 10000 % 10000);
		write_4digit(x % 10000);
	}
	else if (x > 9999) {
		write_4digit_trimmed(x / 10000);
		write_4digit(x % 10000);
	}
	else {
		write_4digit_trimmed(x);
	}
};

$0 \leq i < 10\,000$ 범위의 $i$에 대해 $i$의 $j$번째 자릿수를 lut[i][4 - j]에 저장해 lookup table을 구성하면, 숫자를 문자열로 변환할 때 한 번에 $4$자리씩 버퍼를 채울 수 있습니다.

write_4digit_trimmedwrite_4digit 함수는 $4$자리 이하 또는 정확히 $4$자리 숫자를 memcpy로 한 번에 버퍼에 복사하는 함수입니다. write_int 함수는 x의 자릿수에 대해 조건 분기를 하며 해당 함수를 이용해 x를$4$자리씩 끊어서 출력합니다.

이렇게 $4$자리씩 memcpy를 이용해 lookup table의 값을 버퍼에 복사하면 한 번에 여러 개의 값을 처리할 수 있어 각 자릿수를 하나씩 처리하는 방식보다 효율적입니다.

사용 예시는 다음과 같습니다. (코드)

constexpr auto lut_digit_trimmed = [] {
	std::array<std::array<char, 4>, 10000> res{};
	for (int i = 0; i < 10; i++) {
		res[i][0] = i | 48;
	}
	for (int i = 10; i < 100; i++) {
		res[i][0] = i / 10 | 48;
		res[i][1] = i % 10 | 48;
	}
	for (int i = 100; i < 1000; i++) {
		res[i][0] = i / 100 | 48;
		res[i][1] = i / 10 % 10 | 48;
		res[i][2] = i % 10 | 48;
	}
	for (int i = 1000; i < 10000; i++) {
		res[i][0] = i / 1000 | 48;
		res[i][1] = i / 100 % 10 | 48;
		res[i][2] = i / 10 % 10 | 48;
		res[i][3] = i % 10 | 48;
	}
	return res;
}();

constexpr auto lut_digit = [] {
	std::array<std::array<char, 4>, 10000> res;
	for (int i = 0; i < 10000; i++) {
		res[i][0] = i / 1000 | 48;
		res[i][1] = i / 100 % 10 | 48;
		res[i][2] = i / 10 % 10 | 48;
		res[i][3] = i % 10 | 48;
	}
	return res;
}();

constexpr auto lut_sz = [] {
	std::array<char, 10000> res;
	for (int i = 0; i < 10; i++) res[i] = 1;
	for (int i = 10; i < 100; i++) res[i] = 2;
	for (int i = 100; i < 1000; i++) res[i] = 3;
	for (int i = 1000; i < 10000; i++) res[i] = 4;
	return res;
}();

auto write_4digit_trimmed = [&](int x) {
	memcpy(pw, &lut_digit_trimmed[x], 4);
	pw += lut_sz[x];
};

auto write_4digit = [&](int x) {
	memcpy(pw, &lut_digit[x], 4);
	pw += 4;
};

auto write_int = [&](int x) {
	if (pw - w + 40 > wbuf_sz) write(1, w, pw - w), pw = w;
	if (x < 0) *pw++ = '-', x = -x;
	if (x > 99999999) {
		write_4digit_trimmed(x / 10000 / 10000);
		write_4digit(x / 10000 % 10000);
		write_4digit(x % 10000);
	}
	else if (x > 9999) {
		write_4digit_trimmed(x / 10000);
		write_4digit(x % 10000);
	}
	else {
		write_4digit_trimmed(x);
	}
};

$4$자리 미만 수의 leading zero를 제외한 표현과 자릿수에 대한 lookup table을 추가로 구성하면 write_4digit_trimmed의 구현에서 조건 분기를 없앨 수 있습니다.

사용 예시는 다음과 같습니다. (코드)

2.4 SIMD(SSE2) Vectorized Output

alignas(16) constexpr u32 kDiv10000Vec[4]  = { 3518437209, 3518437209, 3518437209, 3518437209 };
alignas(16) constexpr u32 k10000Vec[4]     = { 10000, 10000, 10000, 10000 };
alignas(16) constexpr u16 kDivPwrVec[8]    = { 8389, 5243, 13108, 32768, 8389, 5243, 13108, 32768 };
alignas(16) constexpr u16 k10Vec[8]        = { 10, 10, 10, 10, 10, 10, 10, 10 };
alignas(16) constexpr u16 kShiftPwrVec[8]  = { 128, 2048, 8192, 32768, 128, 2048, 8192, 32768 };
alignas(16) constexpr char kAsciiZero[16]  = { 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48 };

inline __m128i convert8digits(u32 x) {
	__m128i vv   = _mm_cvtsi32_si128(x);
	__m128i abcd = _mm_srli_epi64(_mm_mul_epu32(vv, ((__m128i*)(kDiv10000Vec))[0]), 45);
	__m128i efgh = _mm_sub_epi32(vv, _mm_mul_epu32(abcd, ((__m128i*)(k10000Vec))[0]));
	__m128i v1   = _mm_unpacklo_epi16(abcd, efgh);
	__m128i v1x4 = _mm_slli_epi64(v1, 2);
	__m128i v2a  = _mm_unpacklo_epi16(v1x4, v1x4);
	__m128i v2   = _mm_unpacklo_epi32(v2a, v2a);
	__m128i v3   = _mm_mulhi_epu16(v2, ((__m128i*)(kDivPwrVec))[0]);
	__m128i v4   = _mm_mulhi_epu16(v3, ((__m128i*)(kShiftPwrVec))[0]);
	__m128i v5   = _mm_mullo_epi16(v4, ((__m128i*)(k10Vec))[0]);
	__m128i v6   = _mm_slli_epi64(v5, 16);
	return _mm_sub_epi16(v4, v6);
}

auto write_int = [&](int x) {
	if (pw - w + 40 > wbuf_sz) write(1, w, pw - w), pw = w;
	if (x < 0) *pw++ = '-', x = -x;
	if (x > 99999999) {
		write_4digit_trimmed(x / 100000000);
		x %= 100000000;
		__m128i digits = convert8digits(x);
		__m128i ascii = _mm_add_epi8(
			_mm_packus_epi16(_mm_setzero_si128(), digits),
			((__m128i*)(kAsciiZero))[0]
		);
		ascii = _mm_srli_si128(ascii, 8);
		_mm_storel_epi64((__m128i*)(pw), ascii);
		pw += 8;
	}
	else if (x > 9999) {
		write_4digit_trimmed(x / 10000);
		write_4digit(x % 10000);
	}
	else {
		write_4digit_trimmed(x);
	}
};

SIMD(Single Instruction Multiple Data)를 이용하면 $32$비트 정수를 한 번에 $8$자리씩 변환해 버퍼에 기록할 수 있습니다. (참고)

구현은 4-digit lookup table과 유사합니다. 만약 출력하는 수가 $8$자리 이하라면 lutmemcpy를 이용해 출력을 처리할 수 있고, $9$자리 이상이라면 앞부분을 lut를 이용해 출력한 뒤 남은 $8$자리를 convert8digits 함수를 이용해 SIMD로 처리합니다.

convert8digits 함수는 기존의 Bit-Twiddling Hack 기반 고속 입력 기법에서 $8$자리 숫자를 처리하던 과정을 역으로 수행하여 구현할 수 있습니다. 이 과정에서는 각 자릿수에 대한 산술 연산이 필요하므로, 단순한 $64$비트 정수 자료형만으로는 충분하지 않으며, 대신 $128$비트 벡터 자료형인 __m128i와 SSE2 명령어 집합을 활용하여 병렬적으로 연산을 수행합니다. 나눗셈 연산은 성능 향상을 위해 Barrett Reduction 기법을 사용하여 곱셈과 bitwise shift 연산으로 대체됩니다.

이 방식은 $9$자리 이상의 정수를 $8$자리 단위로 분할하여 한 번에 출력할 수 있어, 이론적으로는 성능 향상을 기대할 수 있습니다. 그러나 실제 테스트 결과, SIMD를 이용해 $8$자리를 병렬로 처리하는 방식보다 4-digit lookup table을 이용해 $4$자리씩 처리하는 방식이 더 좋은 성능을 보였습니다. 이는 룩업 테이블 기반 방식이 memcpy와 같은 단순한 연산만으로 구현되기 때문인 것으로 생각됩니다.

사용 예시는 다음과 같습니다. (코드)

3. Summary

이번 글에서는 저수준 입출력 기법을 한층 더 최적화하기 위해 Bit-Twiddling Hack, Lookup Table, 그리고 SIMD(Vectorized Instruction) 기술을 활용한 입력 및 출력 기법을 다루었습니다. 입력 최적화에서는 $8$바이트 정수를 묶어서 처리하는 방식, 그리고 std::countr_zero와 같은 비트 연산 기반 기술을 적용하여 조건 분기 없이 정수를 파싱하는 방법을 살펴보았습니다. 이는 한 번에 여러 값을 처리하며 입력 속도를 향상합니다. 출력 최적화에서는 나눗셈 연산의 횟수를 줄이고 여러 자릿수를 한 번에 출력하기 위해 2-digit 및 4-digit lookup table을 이용하였으며, SSE2 명령어 집합을 활용한 $8$바이트 출력 기법을 살펴보았습니다.

지금까지 살펴본 구현 방식의 장단점, 성능을 정리하면 다음과 같습니다.

[1] Fast Input Methods: BOJ 11003

빠른 입력기법 장점 단점 실행 시간 코드
scanf C 표준 입력 느림 1932ms link
cin C++ 표준 입력 느림 1412ms link
read 메모리를 적게 사용 약간 느림 876ms link
mmap 빠름 메모리를 많이 사용 848ms link
8-Byte Parsing 매우 빠름 메모리를 많이 사용 812ms link
8-4-2-1 Byte Parsing 매우 빠름 메모리를 많이 사용 808ms link
fastest - - 152ms link

결과를 보면 read, mmap를 이용한 기본적인 빠른 입력 구현도 실행 시간을 절반 이상 단축할 정도로 충분히 성능이 좋은 걸 알 수 있습니다. 대부분의 경우는 이정도 구현으로 충분한 성능을 얻을 수 있을 것입니다.

가장 좋은 성능을 보인 방법은 Bit-Twiddling Hack을 이용한 8-Byte Parsing입니다. 추가적인 최적화가 필요하다면 해당 방법을 이용하면 좋을 것입니다.

[2] Fast Conditioned Input Methods: BOJ 15552

빠른 입력기법 장점 단점 실행 시간 코드
scanf C 표준 입력 느림 288ms link
cin C++ 표준 입력 느림 220ms link
read 메모리를 적게 사용 약간 느림 108ms link
read ($x \geq 0$, single space) 메모리를 적게 사용 약간 느림 104ms link
mmap 빠름 메모리를 많이 사용 104ms link
mmap ($x \geq 0$, single space) 빠름 메모리를 많이 사용 100ms link
mmap with bitwise operation ($x \geq 0$, single space) 빠름 메모리를 많이 사용 96ms link
8-Byte Parsing with std::countr_zero ($x \geq 0$, single space) 매우 빠름 메모리를 많이 사용, $x$가 $8$자리 이하여야 함 92ms link
4-Byte Parsing with std::countr_zero ($x \geq 0$, single space) 매우 빠름 메모리를 많이 사용, $x$가 $4$자리 이하여야 함 92ms link
fastest - - 12ms link

입력 파일이 특정 조건을 만족한다고 가정하면 빠른 입력 기법을 더욱 최적화할 수 있습니다.

테스트에 사용한 문제는 입력 범위가 $[1, 10^3]$이기 때문에 std::countr_zero를 이용한 4-Byte Parsing이 가장 좋은 성능을 보였습니다.

[3] Fast Output Methods: BOJ 11003

빠른 출력기법 장점 단점 실행 시간 코드
printf C 표준 출력 느림 1932ms link
cout C++ 표준 출력 느림 1584ms link
write 빠름 - 1248ms link
Digit Counting (naive) 빠름 임시 버퍼보다 약간 느림 1252ms link
Digit Counting (binary search) 빠름 코드 길이 증가 1244ms link
Digit Counting (std::bit_width) 매우 빠름 $10^i$ 전처리 필요 1212ms link
2-Digit Lookup Table 빠름 전처리 필요, 코드 길이 증가 1228ms link
4-Digit Lookup Table 매우 빠름 전처리 필요, 코드 길이 증가 1196ms link
4-Digit Lookup Table (optimized) 매우 빠름 전처리 필요, 코드 길이 증가 1188ms link
SIMD(SSE2) 매우 빠름 전처리 필요, 코드 길이 증가, 4-Digit LUT보다 약간 느림 1216ms link
fastest - - 152ms link

빠른 출력 기법은 write를 이용한 구현이 구현 난이도 대비 뚜렷한 성능 향상을 보였습니다. 대부분의 경우는 해당 기법을 이용하면 충분히 좋은 성능을 얻을 수 있을 것입니다.

가장 좋은 성능을 보인 기법은 Lookup Table $3$개를 이용하는 4-Digit Lookup Table 코드입니다. 이는 SSE2를 이용한 SIMD 기법보다 구현 난이도가 낮으면서 더 좋은 성능을 보이기에 최대한 성능을 최적화해야 하는 경우에 사용하면 적합할 것입니다.

눈여겨볼 점은, 빠른 입력, 출력 기법을 모두 적용한 코드의 실행 시간은 152ms에 불과해, 아무 최적화도 적용하지 않은 코드보다 약 $12.7$배 빠른 성능을 보였다는 점입니다. 이는 BOJ 11003처럼 입출력의 양이 많은 경우에 동일한 로직이라도 어떤 입출력 방식을 선택하느냐에 따라 실행 시간이 크게 달라질 수 있음을 나타냅니다.

Fast I/O 기법은 알고리즘 자체를 바꾸지 않고도 성능을 크게 개선할 수 있는 수단이므로, 익혀두면 다양한 문제에서 유용하게 활용될 수 있을 것입니다. 이상으로 글을 마치며, Fast I/O를 공부하고자 하는 분들께 이 글이 의미 있는 자료가 되었기를 바랍니다.

References

[1] https://github.com/miloyip/itoa-benchmark/tree/master

[2] https://github.com/fmtlib/fmt

[3] https://graphics.stanford.edu/~seander/bithacks.html#IntegerLog10

[4] https://www.youtube.com/watch?v=fV6qYho-XVs

[5] https://blog.quarkslab.com/unaligned-accesses-in-cc-what-why-and-solutions-to-do-it-properly.html

[6] https://judge.yosupo.jp/submission/279060