F - Double Sum Editorial by en_translator (official)


This problem is a basic exercise in the sweep line algorithm. If you could not solve it, we recommend learning the sweep line algorithm.

The sought double sum is

\[\sum_{i=1}^N \sum_{i < j} \max(A_j - A_i, 0).\]
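For small inputs, the definition can be evaluated directly. The following is a minimal O(N^2) sketch (the function name `double_sum_naive` is our own) that is too slow for large \(N\) but is handy for checking a faster solution against.

```python
def double_sum_naive(A):
    """Evaluate the double sum directly in O(N^2) time (for checking only)."""
    n = len(A)
    total = 0
    for i in range(n):
        for j in range(i + 1, n):
            # Each pair i < j contributes A_j - A_i if it is positive, otherwise 0.
            total += max(A[j] - A[i], 0)
    return total


print(double_sum_naive([2, 5, 3]))  # 4
```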

Fix \(i\). Since \(\max(A_j - A_i, 0)\) equals \(A_j - A_i\) when \(A_i \leq A_j\) and \(0\) otherwise, the contribution of this \(i\) to the double sum can be written as

\[ \begin{aligned} &\sum_{i < j}\max(A_j - A_i, 0) \\ &= \sum_{i < j,\ A_i \leq A_j} (A_j - A_i) \\ &= (\text{the sum of }A_j\text{ such that }i < j\text{ and }A_i \leq A_j) \\ &\quad - (\text{the number of }j\text{ such that }i < j\text{ and }A_i \leq A_j) \times A_i . \end{aligned} \]
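As a small worked check of this identity (the array is our own illustrative choice), let \(A = (2, 5, 3)\) and \(i = 1\). Both \(A_2 = 5\) and \(A_3 = 3\) are at least \(A_1 = 2\), so the sum of such \(A_j\) is \(8\) and their number is \(2\):

\[ \begin{aligned} \sum_{1 < j}\max(A_j - A_1, 0) &= \max(5 - 2, 0) + \max(3 - 2, 0) = 3 + 1 = 4, \\ (5 + 3) - 2 \times 2 &= 8 - 4 = 4. \end{aligned} \]

For \(i = 2\), no later element is at least \(A_2 = 5\), and \(i = 3\) has no later elements, so both contribute \(0\) and the whole double sum is \(4\).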

Using this fact, this problem can be solved by the sweep line algorithm as follows.

  • Prepare data structures for the following two multisets:
    • A multiset \(S_0\) that supports two kinds of queries: inserting an element, and retrieving the number of elements not less than \(x\).
    • A multiset \(S_1\) that supports two kinds of queries: inserting an element, and retrieving the sum of elements not less than \(x\).
  • Also, prepare a variable \(\mathrm{ans}\) that stores the answer. Initially, let \(\mathrm{ans} = 0\).
  • For each \(i = N, N-1, \dots, 2, 1\) in this order, perform the following. (At this point, \(S_0\) and \(S_1\) contain exactly \(A_{i+1}, \dots, A_N\).)
    • Let \(c\) be the response to the query against \(S_0\) with \(x = A_i\).
    • Let \(s\) be the response to the query against \(S_1\) with \(x = A_i\).
    • Add \(s - c \times A_i\) to \(\mathrm{ans}\).
    • Insert \(A_i\) into both \(S_0\) and \(S_1\).
  • Print the resulting value of \(\mathrm{ans}\).
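To make the order of operations concrete, here is a minimal sketch of the steps above in which both multisets are replaced by a single plain Python list scanned linearly. This replacement is an assumption for illustration only: each query then costs O(N), so the loop is O(N^2) overall. The point is the sweep order, not the speed; the function name is our own.

```python
def double_sum_sweep_naive(A):
    """Sweep i = N, ..., 1 keeping A_{i+1}, ..., A_N in a plain list (O(N^2))."""
    seen = []  # stands in for both S_0 and S_1: holds A_j for every j > i
    ans = 0
    for a in reversed(A):
        c = sum(1 for v in seen if v >= a)  # query on S_0 with x = A_i
        s = sum(v for v in seen if v >= a)  # query on S_1 with x = A_i
        ans += s - c * a
        seen.append(a)  # insert A_i so that it counts for smaller indices
    return ans


print(double_sum_sweep_naive([2, 5, 3]))  # 4
```

Swapping the list for the data structures described next keeps exactly the same loop and only changes the cost of each query and insertion.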

\(S_0\) and \(S_1\) can each be implemented as a Fenwick tree (binary indexed tree) over coordinate-compressed values; each query and each insertion then takes \(\mathrm{O}(\log N)\) time.
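For readers without the AtCoder Library at hand, the following is a minimal self-contained sketch of a Fenwick tree whose `add(p, x)` and `sum(l, r)` methods (half-open range \([l, r)\)) mimic the interface used in the sample code, together with coordinate compression via `bisect`. The class is our own illustration, not the library's implementation. Because the values are compressed to indices \(0, \dots, M-1\) with \(M \leq N\), each operation touches \(\mathrm{O}(\log N)\) cells.

```python
import bisect


class FenwickTree:
    """Minimal Fenwick tree (binary indexed tree) over indices 0..n-1."""

    def __init__(self, n):
        self.n = n
        self.data = [0] * (n + 1)  # 1-indexed internally

    def add(self, p, x):
        """Add x to position p (0-indexed)."""
        p += 1
        while p <= self.n:
            self.data[p] += x
            p += p & -p

    def _prefix(self, r):
        """Return the sum of positions [0, r)."""
        total = 0
        while r > 0:
            total += self.data[r]
            r -= r & -r
        return total

    def sum(self, l, r):
        """Return the sum of positions [l, r)."""
        return self._prefix(r) - self._prefix(l)


# Usage: compress the values, then keep one tree for counts (S_0)
# and one tree for sums (S_1).
A = [2, 5, 3]
B = sorted(set(A))  # compressed coordinates: [2, 3, 5]
cnt = FenwickTree(len(B))
tot = FenwickTree(len(B))
for a in (5, 3):  # insert A_2 = 5 and A_3 = 3
    k = bisect.bisect_left(B, a)
    cnt.add(k, 1)
    tot.add(k, a)
k = bisect.bisect_left(B, 2)  # query with x = A_1 = 2
print(cnt.sum(k, len(B)), tot.sum(k, len(B)))  # 2 8
```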

Therefore, the problem can be solved in a total of \(\mathrm{O}(N \log N)\) time, which is fast enough.

  • Sample code (Python)
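```python
from atcoder.fenwicktree import FenwickTree  # AtCoder Library (Python port)
import bisect

N = int(input())
A = list(map(int, input().split()))

# Coordinate compression: B is the sorted list of distinct values of A.
B = sorted(set(A))
M = len(B)
sum0 = FenwickTree(M)  # S_0: number of inserted elements per compressed index
sum1 = FenwickTree(M)  # S_1: sum of inserted elements per compressed index
ans = 0
# Sweep i = N-1, ..., 0 (0-indexed); the trees hold A[i+1], ..., A[N-1].
for i in reversed(range(N)):
    k = bisect.bisect_left(B, A[i])  # compressed index of A[i]
    c = sum0.sum(k, M)  # number of j > i with A[j] >= A[i]
    s = sum1.sum(k, M)  # sum of A[j] over j > i with A[j] >= A[i]
    ans += s - c * A[i]
    sum0.add(k, 1)
    sum1.add(k, A[i])
print(ans)
```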
