kajebiii's profile image

kajebiii

September 17, 2019 20:25

AtCoder Regular Contest 075 C Meaningful Mean 풀이

problem-solving , programming-contest , Atcoder

문제 요약

문제 링크

$N$개의 수가 있다. $i$번째 수를 $a_i$라 하자.

수열 $a$의 $N$개의 수에서 연속한 수들을 고르는 방법의 가지수는 $\frac{N \times (N+1)}{2}$ 개이다.

$\frac{N \times (N+1)}{2}$ 개 중에서 평균이 $K$ 이상인 것이 몇개인지 구하라.

$ 1 \le N \le 2 \times 10^5, 1 \le K \le 10^9, 1 \le a_i \le 10^9 $

풀이

1. $O(N^3)$

Naive 한 방법.

$\frac{N \times (N+1)}{2}$ 개의 모든 경우에 대해서, 직접 합을 구하고 평균이 $K$ 이상인지를 확인한다.

2. $O(N^2)$

$S_i = \sum_{k=1}^{i} a_k$ 인 배열을 미리 구해놓자. $S_i$는 $1$번째 부터 $i$번째의 $a$ 수열의 합이며, 부분합으로 부른다. $S_0 = 0$ 으로 정의하자.

부분합을 이용하면 특정 구간 $[l, r]$에 대해서, 구간 합을 $S_r - S_{l-1}$ 으로 $O(1)$ 로 구할 수 있다.

즉, 1. 번 풀이에서 직접 합을 구하는 부분을 부분합을 이용하여 바꾸면 $O(N^2)$ 풀이가 완성된다.

3. $O(N \log N)$

아이디어를 소개한다. 입력받은 수에 모두 $K$를 빼준 수열을 $b_i = a_i - K$ 를 생각하자.

그럼 이제 수열 $a$에 대해서 문제를 풀지말고, 수열 $b$에 대해서 문제를 풀어보자.

수열 $b_l, b_{l+1}, \cdots, b_{r}$의 평균이 $K$ 이상인 것과 $b_l, b_{l+1}, \cdots, b_{r}$의 평균이 $0$ 이상인 것이 동치이다. 즉, 수열 $b$의 $\frac{N \times (N+1)}{2}$ 개의 합들 중에서 평균이 $0$ 이상인 것의 개수를 세는 문제로 바꾼다. 이제 이 문제를 풀어보자.

$[l, r]$의 합은 위에서 언급했듯이 $S_r - S_{l-1}$이다. 그리고 평균이 $0$ 이상이어야 하므로, 그 조건을 써보고 정리해보면,

$\frac{S_r - S_{l-1}}{r-l} \ge 0 \iff S_r - S_{l-1} \ge 0 \iff S_r \ge S_{l-1} $

이다.

즉, 수열 $S$에 대해서 $S_{l-1} \le S_{r}$를 만족하는 $(l, r)$ 순서쌍 개수를 찾는 문제와 같다. (단 $l \le r$)

자료구조를 하나 준비하자, 점 갱신, 구간 합 쿼리가 $O(\log N)$ 가능하면 모든 가능하다. 모든 점의 초기값을 $0$으로 되어 있다고 생각하자. 아래 코드는 BIT 를 이용하였다.

$r$을 하나씩 늘리면서, $[l, r]$의 평균이 $0$ 이상인 $l$의 개수를 찾는 방식으로 진행할 것이다. $S_i$들 각각이 하나의 점이 되고, $r$에 대한 조사가 끝난 경우 $S_r$에 대응 되는 점을 $1$로 갱신한다. 그러면 $S_{l-1} \le S_r$ 을 만족하는 $l$의 개수는 $[-\inf, S_r]$의 구간 합 쿼리 값과 같다.

그러면 각각의 $r$에 대해서 쿼리를 $2$번 씩 이용하므로, 시간 복잡도는 $O (N \log N)$이다.

코드

#include <bits/stdc++.h>
 
using namespace std;
 
#define REP(i,n) for(int (i)=0;(i)<(int)(n);(i)++)
#define SZ(v) ((int)(v).size())
#define ALL(v) (v).begin(),(v).end()
#define one first
#define two second
typedef long long ll;
typedef pair<int, int> pi;
const int INF = 0x3f2f1f0f;
const ll LINF = 1ll * INF * INF;
 
const int MAX_N = 2e5 + 100;
 
int N; ll K, Nr[MAX_N];
vector<ll> Co;
int ix(ll k) {return upper_bound(ALL(Co), k) - Co.begin() - 1;} // equal less (max index)
struct BIT {
	ll A[MAX_N];
	BIT() {
		for(int i=0; i<MAX_N; i++) A[i] = 0;
	}
	void update(int p, ll v) {
		for(; p < MAX_N; p += p &(-p)) A[p] += v;
	}
	ll getSum(int p) {
		ll rv = 0; for(; p; p -= p & (-p)) rv += A[p];
		return rv;
	}
};
int main() {
	cin >> N >> K; 
	for(int i=1; i<=N; i++) {
		scanf("%lld", &Nr[i]);
		Nr[i] -= K;
	}
	Co.push_back(-LINF);
	Co.push_back(0);
	for(int i=1; i<=N; i++) Nr[i] += Nr[i-1], Co.push_back(Nr[i]);
	sort(ALL(Co)); Co.erase(unique(ALL(Co)), Co.end());
 
	BIT bit; bit.update(ix(0), 1);
	ll ans = 0;
	for(int i=1; i<=N; i++) {
		ans += bit.getSum(ix(Nr[i]));
		bit.update(ix(Nr[i]), 1);
	}
	printf("%lld\n", ans);
	return 0;
}