AtCoder Regular Contest 075 C Meaningful Mean 풀이
문제 요약
$N$개의 수가 있다. $i$번째 수를 $a_i$라 하자.
수열 $a$의 $N$개의 수에서 연속한 수들을 고르는 방법의 가지수는 $\frac{N \times (N+1)}{2}$ 개이다.
$\frac{N \times (N+1)}{2}$ 개 중에서 평균이 $K$ 이상인 것이 몇개인지 구하라.
$ 1 \le N \le 2 \times 10^5, 1 \le K \le 10^9, 1 \le a_i \le 10^9 $
풀이
1. $O(N^3)$
Naive 한 방법.
$\frac{N \times (N+1)}{2}$ 개의 모든 경우에 대해서, 직접 합을 구하고 평균이 $K$ 이상인지를 확인한다.
2. $O(N^2)$
$S_i = \sum_{k=1}^{i} a_k$ 인 배열을 미리 구해놓자. $S_i$는 $1$번째 부터 $i$번째의 $a$ 수열의 합이며, 부분합으로 부른다. $S_0 = 0$ 으로 정의하자.
부분합을 이용하면 특정 구간 $[l, r]$에 대해서, 구간 합을 $S_r - S_{l-1}$ 으로 $O(1)$ 로 구할 수 있다.
즉, 1. 번 풀이에서 직접 합을 구하는 부분을 부분합을 이용하여 바꾸면 $O(N^2)$ 풀이가 완성된다.
3. $O(N \log N)$
아이디어를 소개한다. 입력받은 수에 모두 $K$를 빼준 수열을 $b_i = a_i - K$ 를 생각하자.
그럼 이제 수열 $a$에 대해서 문제를 풀지말고, 수열 $b$에 대해서 문제를 풀어보자.
수열 $b_l, b_{l+1}, \cdots, b_{r}$의 평균이 $K$ 이상인 것과 $b_l, b_{l+1}, \cdots, b_{r}$의 평균이 $0$ 이상인 것이 동치이다. 즉, 수열 $b$의 $\frac{N \times (N+1)}{2}$ 개의 합들 중에서 평균이 $0$ 이상인 것의 개수를 세는 문제로 바꾼다. 이제 이 문제를 풀어보자.
$[l, r]$의 합은 위에서 언급했듯이 $S_r - S_{l-1}$이다. 그리고 평균이 $0$ 이상이어야 하므로, 그 조건을 써보고 정리해보면,
$\frac{S_r - S_{l-1}}{r-l} \ge 0 \iff S_r - S_{l-1} \ge 0 \iff S_r \ge S_{l-1} $
이다.
즉, 수열 $S$에 대해서 $S_{l-1} \le S_{r}$를 만족하는 $(l, r)$ 순서쌍 개수를 찾는 문제와 같다. (단 $l \le r$)
자료구조를 하나 준비하자, 점 갱신, 구간 합 쿼리가 $O(\log N)$ 가능하면 모든 가능하다. 모든 점의 초기값을 $0$으로 되어 있다고 생각하자. 아래 코드는 BIT 를 이용하였다.
$r$을 하나씩 늘리면서, $[l, r]$의 평균이 $0$ 이상인 $l$의 개수를 찾는 방식으로 진행할 것이다. $S_i$들 각각이 하나의 점이 되고, $r$에 대한 조사가 끝난 경우 $S_r$에 대응 되는 점을 $1$로 갱신한다. 그러면 $S_{l-1} \le S_r$ 을 만족하는 $l$의 개수는 $[-\inf, S_r]$의 구간 합 쿼리 값과 같다.
그러면 각각의 $r$에 대해서 쿼리를 $2$번 씩 이용하므로, 시간 복잡도는 $O (N \log N)$이다.
코드
#include <bits/stdc++.h>
using namespace std;
#define REP(i,n) for(int (i)=0;(i)<(int)(n);(i)++)
#define SZ(v) ((int)(v).size())
#define ALL(v) (v).begin(),(v).end()
#define one first
#define two second
typedef long long ll;
typedef pair<int, int> pi;
const int INF = 0x3f2f1f0f;
const ll LINF = 1ll * INF * INF;
const int MAX_N = 2e5 + 100;
int N; ll K, Nr[MAX_N];
vector<ll> Co;
int ix(ll k) {return upper_bound(ALL(Co), k) - Co.begin() - 1;} // equal less (max index)
struct BIT {
ll A[MAX_N];
BIT() {
for(int i=0; i<MAX_N; i++) A[i] = 0;
}
void update(int p, ll v) {
for(; p < MAX_N; p += p &(-p)) A[p] += v;
}
ll getSum(int p) {
ll rv = 0; for(; p; p -= p & (-p)) rv += A[p];
return rv;
}
};
int main() {
cin >> N >> K;
for(int i=1; i<=N; i++) {
scanf("%lld", &Nr[i]);
Nr[i] -= K;
}
Co.push_back(-LINF);
Co.push_back(0);
for(int i=1; i<=N; i++) Nr[i] += Nr[i-1], Co.push_back(Nr[i]);
sort(ALL(Co)); Co.erase(unique(ALL(Co)), Co.end());
BIT bit; bit.update(ix(0), 1);
ll ans = 0;
for(int i=1; i<=N; i++) {
ans += bit.getSum(ix(Nr[i]));
bit.update(ix(Nr[i]), 1);
}
printf("%lld\n", ans);
return 0;
}