suta精進記

競プロ関連体験記や精進記録などを書きます

Educational Codeforces Round 97 F : Emotional Fisherman

問題リンク :

codeforces.com

問題概要

長さ$N$の数列$A$が与えられ、これの並べ替えを考える。
左から見た時に、それより前までのmaxを記録して現在見ている数字がmaxの2倍以上またはmaxの半分以下であるような数列の並べ替えの総数をmod $998244353$で求めよ。
ただし一番左の要素を見ているときのmaxは0とする。
例えば1 1 4 9は2番目を見た時に条件を満たしていない。4 9 1 1は条件を満たす。

解法

以降解法では全て1-indexで考えます。
挿入dpのような考え方をします。
数列を昇順にソートしてから、状態として

  • 何番目を見ているか
  • max(をとるindex)

をもつことを考えます。すなわち $$ dp[i][j] = i番目まで見てmaxがa_jであるときの場合の数 $$ として動的計画法を行うことを考えます。

その前に、前処理として以下の値を求めておきます。

  • $pre[j] = a_j \geqq 2a_k をみたすkの最大値$

数列は昇順なので、$k$以下の$a$の値は全て$a_j$の半分以下であることが保証されます。

dpにより$i+1$番目の数を決める時に、考えるべき遷移は2つです。

① $a_j$ よりも小さな値を挿入するとき
$i$番目まで見たという事実から、$i$番目までの数は全て$a[pre[j]]$よりも小さいということが確定しています。なぜならmaxをとるindexは昇順に更新されるからです。
これによってこれまで何を選んだかを考える必要がなくなり、$pre[j]$以下の値がまだ何個使用されていないかを係数とすることで更新が可能になります。
そして、$a_j$よりも小さな値を挿入するとき、maxは再び$a_j$となります。
よって更新式は以下のようになります。 $$ dp[i+1][j]= dp[i][j] * (pre[j] - i + 1) $$ $j$番目の数が$i$個の中に1つ含まれるため(maxが$a[j]$なのでそれはそう)、+1が入ることに注意します。
$pre[j]$は1-indexedであるため、0-indexedで実装する際は+2となります。

② $a_j$ よりも大きな値を挿入するとき
数列は昇順なので、大きな値として$a_k$を取ることにすると$j<k$となります。
これを貰うdpとして考えると、次にmaxをとるindexとして$k$を選ぶとき、それまでのmaxの2倍以上でなくてはならないので$pre[k]$以下の場合の数を足しこむ形になります。
よって更新式は以下のようになります。 $$ dp[i+1][k] = \sum_{l=1}^{pre[k]} dp[i][l] $$

高速化

このdpは状態数が$O(N^{2})$、①の方の遷移が$O(1)$、②の方の遷移が$O(N)$となっており、全体の計算量は$O(N^{3})$です。
しかし、②の遷移は累積和を使用することで$O(1)$で更新できるようになります。

以上で$O(N^{2})$でこの問題を解くことができます。

実装

#include <iostream>
#include <algorithm>
using namespace std;
const int mod = 998244353;

void chmax(int &a, int b) {if(a < b) a = b;}

int a[5050], pre[5050];
long long dp[5050], sdp[5050];

int main(){
    int n; cin >> n;
    for(int i=0; i<n; i++) cin >> a[i];
    sort(a,a+n);
    for(int i=0; i<n; i++) pre[i] = -1;
    for(int i=0; i<n; i++){
        for(int j=i+1; j<n; j++){
            if(a[j] >= 2*a[i]) chmax(pre[j],i);
        }
    }
    for(int i=0; i<=n; i++) sdp[i] = 1;
    for(int i=0; i<n; i++){
        for(int j=0; j<n; j++){
            if(pre[j] >= i-1) dp[j+1] = dp[j+1] * (pre[j] - i + 2) % mod;
            else dp[j+1] = 0;
            (dp[j+1] += sdp[pre[j]+2]) %= mod;
        }
        sdp[0] = 0;
        for(int j=0; j<=n; j++){
            sdp[j+1] = (sdp[j] + dp[j]) % mod;
        }
    }
    cout << dp[n] << "\n";
}

感想

累積和dpの添え字はいつも悩みますね
遷移を完全に把握した状態で書き始めないとでたらめな実装になりがち