Official

Ex - Group Photo Editorial by yuto1115

解説

以下、\(A,B\) はそれぞれソートされていることを仮定します。

まず、数列 \(c = (c_1,c_2,\dots,c_{N+1})\) を以下のように定義します。

  • \(c_i=\min\{a_{i-1},a_i\}\ (2 \leq i \leq N)\)
  • \(c_1=a_1\)
  • \(c_{N+1}=a_N\)

このとき、「良い並び方である」ことと「すべての \(i\ (1 \leq i \leq N+1)\) について \(b_i > c_i\) である」ことは同値です。 よって、\(c\) を昇順にソートした配列を \(C\) とおけば、「後列の並び方をうまく定めることで良い並び方にできる」ことは「すべての \(i\ (1 \leq i \leq N+1)\) について \(B_i > C_i\) である」ことと同値です。

\(A_1,A_2,\dots,A_N\) の順に相対的な位置関係を決めていく挿入 DP を考えます。ただし、\(A_1,A_2,\dots,A_i\) の位置関係を決めた段階で隣接している各要素のペアについて、「最終的な \(A\) においても隣接している」のか「\(A_{i+1}\) 以降の要素が間に挟まる」のかを区別します。そして、「最終的な \(A\) においても隣接している」要素のペア間に辺を張ったときの連結成分数を、DP の第 \(2\) のキーとして管理します。遷移の際(すなわち \(A_{i+1}\) を挿入する際)は、

  • \(A_{i+1}\) が単独で新たな連結成分を成す
  • 既に存在する連結成分の右端または左端に \(A_{i+1}\) がくっつく
  • 既に存在する \(2\) つの連結成分の間に \(A_{i+1}\) が挟まり、それぞれの連結成分と \(A_{i+1}\) がくっつくことで連結成分数が \(1\) 減る

\(3\) パターンを考える必要がありますが、そのいずれにおいても遷移の際にかかる係数は簡単な形で表せます。

最後に、連結成分数という情報さえ持っていれば「後列の並び方をうまく定めることで良い並び方にできる」かどうかを判定できることを確認します。

\(A_1,A_2,\dots,A_i\) まで決めた段階で連結成分数が \(j\) であるという情報からは、\(C\) のうち \(A_i\) 以下である要素は \(i+j\) 個あるという情報が得られます。この情報は条件を判定するのに十分です。具体的には、「\(A_i\) まで決めて、\(C\) のうち \(A_i\) 以下である要素の数は \(k\) 個」という状態から「\(A_{i+1}\) まで決めて、\(C\) のうち \(A_{i+1}\) 以下である要素の数は \(l\) 個」という状態へ遷移する際に、\(B_{k+1},B_{k+2},\dots,B_l\) がすべて \(A_{i+1}\) 以上であるかどうかをチェックすればよいです。

よって本問題を \(O(N^2)\) で解くことができました。以下の図や実装例も参考にしてください。

図の説明:上から下に向かって DP が遷移する様子を表しています。\(i\) は今までに挿入した要素の数、\(j\) は連結成分数、\(k\ (=i+j)\)\(C\) のうち \(A_i\) 以下である要素の数です。\(C\) のうち \(A_i\) 以下である \(k\) 個の要素の位置が V 印で示されています(\(a_p\)\(a_{p+1}\) の間にある V 印が \(C_{p+1}\) を表します)。

実装例 (C++) :

#include<bits/stdc++.h>
#include<atcoder/modint>

using namespace std;
using namespace atcoder;

using mint = modint998244353;

int main() {
    int n;
    cin >> n;
    vector<int> a(n), b(n + 1);
    for (int &i: a) cin >> i;
    for (int &i: b) cin >> i;
    sort(a.begin(), a.end());
    sort(b.begin(), b.end());
    vector dp(n + 1, vector<mint>(n + 1));
    dp[0][0] = 1;
    for (int i = 0; i < n; i++) {
        for (int j = 0; j <= i; j++) {
            for (int p = 0; p < 3; p++) {
                if (p > j) break;
                if (i + j + 2 - p > n + 1) continue;
                if (p <= 1 and b[i + j] < a[i]) continue;
                mint coef;
                if (p == 0) coef = j + 1;
                else if (p == 1) coef = 2 * j;
                else coef = j - 1;
                dp[i + 1][j + 1 - p] += dp[i][j] * coef;
            }
        }
    }
    cout << dp[n][1].val() << endl;
}

posted:
last update: