noshi91さんの記事 を大いに参考にしています。実装に苦労したので載せます。 verify 用問題は現在 Library Checker で作業しています。
以下の問題を考えます。
整数 と長さ の非負整数列 が与えられます。非負整数列 で以下の条件を満たすものの個数を で求めてください。
ただし は保証されています。
条件式の は に を足せば に変えれるので、以降上の問題文の を に変えたものを考えます。
次に として良いのでこのように を変更します。実装上においては は が小さいものから、 は が大きいものから上の式を当てはめれば良いです。
また、 の全てを同じ数減らしても良いので、 とします。
まず、部分問題である の場合すなわち上限のみが指定されている場合を考えます。
を となる個数と考えるナイーブな DP を考えるとこの問題は で解けることがひとまず分かります。
この DP の遷移は を満たしているとして においては とし、 では です。
ここで、 DP の遷移をよく考えると を初期値として としても問題ないことが分かります。答えは なので、この問題は以下のように言い換えることができます。
平面において 上には の範囲で線分が引かれている。ただし 上には の範囲で線分が引かれている。線分がその上を通っている各格子点において、一つ右(すなわち 座標に 足したもの)の格子点にも線分が通っている場合、その 点間にも線分を引く。このようにしてできた図形において から への経路は何通りか。
例えば においては以下のような図形になります。
ここで、格子点の左上に書かれた赤い数字はいわゆる経路数え上げ DP の途中結果であり、特に は に書かれた赤い数字です。別の経路と数列の言い換えをすると、 回目に右に進んだ時の 座標の値が に対応しています。例えば青色の経路は に、緑色の経路は に対応しています。
これを分割統治で解くことを考えます。まず、問題を言い換えて一番下の線分には数字 が書かれている状態から始め、一番右の線分に書かれた赤い数字 個(以降 とします) を列挙するものとします。
具体例として は上の画像のままで とする場合を考えます。下の画像の通り がこの場合の答えです。
下辺の真ん中を通る縦線 で図形を分け、その縦線における値を求めます。上の画像なら であるので 上の赤い数字です。ただし、実装上ではここで再帰的に関数を呼び出しますがその際 の値(今回なら )を にして呼び出してください。今回なら返り値は となります。
まず、 を求めます。 による寄与、 における答え による寄与を考えると、以下の 種類の問題が解ければ良いです。イメージとしては下の画像のようになります。
数列 と が与えられる。 として、 を求めよ。
数列 と が与えられる。 として、 を求めよ。
まず について解きます。 と式変形できるので数列 として を convolution すると 番目から順に となります。
次に について解きます。 と式変形できます。ただし負の階乗がある部分については二項係数の関係上、適切に とする必要があります。数列 として(ただし の先頭の は 個です) を convolution すると 番目から順に となります。
以下が ACL を使用した C++ での実装例です。
実装例
#include <atcoder/modint>
#include <atcoder/convolution>
using namespace std;
using namespace atcoder;
using mint = modint998244353;
vector<mint> fac,finv,inv;
// i!, 1/i!, 1/i
/*
適切に fac,finv,inv の前準備がされているものとする
*/
// v = (v_0, v_1, ... , v_{n-1}) に対し
// f(x) = sum_i ( n-1-i+xCx v_i ) とする
// {f(0), f(1), ... , f(m-1)} を返す
vector<mint> enumerate_f(int n,vector<mint> &v,int m){
vector<mint> _v=v;
reverse(_v.begin(),_v.end());
for(int i=0;i<n;i++)_v[i]*=finv[i];
reverse(_v.begin(),_v.end());
// _v = (v_{0}/(n-1)!, v_{1}/(n-2)!, ... , v_{n-1}/0!)
vector<mint> fsub(_n+m);
for(int i=0;i<n+m;i++)fsub[i]=fac[i];
vector<mint> f=convolution(_v,fsub),res;
for(int i=0;i<m;i++)res.emplace_back(f[_n-1+i]*finv[i]);
return res;
}
// v = (v_0, v_1, ... , v_{n-1}), l に対し
// g(x) = sum_i ( l-i+xCl v_i ) とする
// {g(0), g(1), ... , g(n-1)} を返す
vector<mint> enumerate_g(int n,vector<mint> &v,int l){
vector<mint> _v=v;
vector<mint> gsub(n-1,0);
for(int i=0;i<n;i++)gsub.emplace_back(fac[l+i]*finv[i]);
vector<mint> g=convolution(_v,gsub),res;
for(int i=0;i<n;i++)res.emplace_back(g[_n-1+i]*finv[l]);
return res;
}
ここで、右上部分の再帰のために 部分の値も求める必要があり、これも縦横を入れ替えて考えることで先ほどと同じようになります。ただし、 についての の値を求める際にメビウス変換を施す必要があります。なぜなら、 の部分は既に格子上の数え上げの求め方を行っているために、上側の値は下側の値の影響を受けており、このままでは同じ経路を複数回数え上げることになるためです。
(追記:
Mitarushi さんより、メビウス変換した後の結果は を整数として で区切った時の値とも考えられると指摘されました。その理解の方がわかりやすいかも知れません。 )
このメビウス変換については右上部分の再帰における引数についても行う必要があることに注意してください。
つまり上の例では を引数として左下部分を考え、 を得ます。その後これをメビウス変換して とします。次にこれと下の部分である から と の部分の値を求めます。これにより の部分は となります。
これもやはりメビウス変換して とし、これを引数部分にして右上部分を考えます。その結果 を得ます。 とまとめて全体としての結果は となり、特に求める問題の答えは です。
全体としての実装は以下のようになります。
実装例
// 長さ n の(広義単調増加な)非負整数列 a に対し以下を考える
// 横の長さを n とする辺がある 左から i の地点から a[i] だけ上に伸びる辺がある
// 上手い具合に横線が引いてある
// また下辺の左から i の地点には start[i] が書かれている
// この状態でマス目数え上げ DP をしたとき、一番右の辺に書かれた数字 a[n-1] 個を返す
vector<mint> sub(int n,vector<int> &a,vector<mint> &start){
int m=a[n-1];
vector<mint> res(m+1);
if(n==1){
for(int i=0;i<m+1;i++)res[i]=start[0];
return res;
}
if(n==2){
for(int i=0;i<m+1;i++)res[i]=start[0]*min(i+1,a[0]+1)+start[1];
return res;
}
//n > 2
int mid=n/2;
int m_front=a[mid];
vector<int> a_front(mid+1),a_back(n-mid);
vector<mint> start_front(mid+1),start_end(n-mid,0);
for(int i=0;i<mid+1;i++){
a_front[i]=_a[i];
start_front[i]=start[i];
if(i==mid)start_front[i]=0;
}
for(int i=mid;i<n;i++){
a_back[i-mid]=a[i]-m_front;
start_end[i-mid]=start[i];
}
vector<mint> sub_front=sub(mid+1,a_front,start_front);
// sub_front は長さ m_front+1
for(int i=m_front;i>=1;i--)sub_front[i]-=sub_front[i-1];
vector<mint> sub_front_f=enumerate_f(m_front+1,sub_front,n-mid),sub_front_g=enumerate_g(m_front+1,sub_front,n-mid-1);
vector<mint> start_end_f=enumerate_f(n-mid,start_end,m_front+1),start_end_g=enumerate_g(n-mid,start_end,m_front);
for(int i=0;i<m_front;i++)res[i]=sub_front_g[i]+start_end_f[i];
for(int i=0;i<n-mid;i++)start_end[i]=sub_front_f[i]+start_end_g[i];
for(int i=n-mid-1;i>=1;i--)start_end[i]-=start_end[i-1];
vector<mint> sub_end=sub(n-mid,a_back,start_end);
for(int i=0;i<(int)sub_end.size();i++)res[i+m_front]=sub_end[i];
return res;
}
全体の計算量は です。
の制約がない場合を考えます。
まず、先ほどと同じように図形上の経路数え上げに帰着させます。 の場合において の場合も数える必要があることに注意すると、以下のような図形を考えれば良いです。
平面において 上には の範囲で線分が引かれている。ただし 上には の範囲で、 上には の範囲で線分が引かれている。線分がその上を通っている各格子点において、一つ右(すなわち 座標に 足したもの)の格子点にも線分が通っている場合、その 点間にも線分を引く。このようにしてできた図形において から への経路は何通りか。
から右に進めるだけ進み、突き当たりでは上に進めるだけ進み、と繰り返し に到達する経路を考えます。
例えば では以下のようになります。
となります。ここで、この経路は複数の線分の結合として書けます。
上なら格子点の列 について、隣接する 点を結ぶ線分からなるものが赤い経路です。
ここで、線分から線分へ書かれた数字を求めるのは先ほど説明した がある場合の問題(上の実装例では sub
に相当)を解けば良いです。この場合では から と求めていきます。
ただし、上でも説明したように適宜メビウス変換を施してから上で説明した問題(上の実装例では sub
に相当)を解く処理を行わせることに注意してください。
実装例
// https://noshi91.hatenablog.com/entry/2023/07/21/235339
#include <atcoder/modint>
#include <atcoder/convolution>
using namespace std;
using namespace atcoder;
#include <iostream>
// 以下を満たす "広義" 単調増加の整数列 x を数える
// a_i <= x_i <= b_i
// a,b の単調性は要求しない(内部で補正する)
// ただし、 a,b が非負整数列であることは要求される
struct number_of_increasing_sequences_between_two_sequences{
private:
using ll = long long;
using mint = static_modint<998244353>;
#define all(a) a.begin(),a.end()
#define rep(i,start,end) for(ll i=start;i<(ll)(end);i++)
#define per(i,start,end) for(ll i=start;i>=(ll)(end);i--)
int n;
vector<int> a,b;
int zelo_flg=0;
long long mod = 998244353;
vector<mint> fac,finv,inv;
// i!, 1/i!, 1/i
// v = (v_0, v_1, ... , v_{n-1}) に対し
// f(x) = sum_i ( n-1-i+xCx v_i ) とする
// {f(0), f(1), ... , f(m-1)} を返す
vector<mint> enumerate_f(int _n,vector<mint> &v,int m){
// 省略
}
// v = (v_0, v_1, ... , v_{n-1}), l に対し
// g(x) = sum_i ( l-i+xCl v_i ) とする
// {g(0), g(1), ... , g(n-1)} を返す
vector<mint> enumerate_g(int _n,vector<mint> &v,int l){
// 省略
}
// 長さ _n の(広義単調増加な)非負整数列 _a に対し以下を考える
// 横の長さを _n とする辺がある 左から i の地点から _a[i] だけ上に伸びる辺がある
// 上手い具合に横線が引いてある
// また下辺の左から i の地点には start[i] が書かれている
// この状態でマス目数え上げ DP をしたとき、一番右の辺に書かれた数字 _a[_n-1] 個を返す
vector<mint> sub(int _n,vector<int> &_a,vector<mint> &start){
// 省略
}
public:
number_of_increasing_sequences_between_two_sequences() = default;
number_of_increasing_sequences_between_two_sequences(int _n,vector<int> _a,vector<int> _b){
n=_n,a=_a,b=_b;
rep(i,0,n){
if(a[i]>b[i])zelo_flg=1;
if(i>0&&b[i]<a[i-1])zelo_flg=1;
}
per(i,n-2,0)b[i]=min(b[i],b[i+1]);
rep(i,1,n)a[i]=max(a[i],a[i-1]);
int al=a[n-1],bl=b[n-1];
per(i,n-1,1)a[i]=min(a[i],a[i-1]);
int dec=a[0];
rep(i,0,n)a[i]-=dec;
rep(i,0,n)b[i]-=dec;
n++;
a.emplace_back(al-dec);
b.emplace_back(bl+5-dec);
rep(i,0,n){
if(a[i]>b[i])zelo_flg=1;
if(i>0&&b[i]<a[i-1])zelo_flg=1;
}
int m=max((n+max(a[n-1],b[n-1]))*2+10,100);
fac.resize(m);finv.resize(m);inv.resize(m);
fac[0]=fac[1]=1;
finv[0]=finv[1]=1;
inv[1]=1;
rep(i,2,m){
fac[i]=fac[i-1]*i;
inv[i]=-inv[mod%i]*(mod/i);
finv[i]=finv[i-1]*inv[i];
}
}
void debug_a(){
for(int val:a)cout<<val<<" ";
cout<<endl;
}
void debug_b(){
for(int val:b)cout<<val<<" ";
cout<<endl;
}
// 以下を満たす "広義" 単調増加の整数列 x の個数を返す
// a_i <= x_i <= b_i
mint answer(){
if(zelo_flg)return 0;
if(n==1)return (mint)b[0]-a[0]+1;
int dist=upper_bound(all(a),a[0])-a.begin();
// [0,dist) までは a_i = a_0
int px=0,py=a[0];
int qx=dist-1,qy=a[0];
if(qx==0)qy=b[0];
vector<mint> now(abs(qx-px)+abs(qy-py)+1,0);
now[0]=1;
while(qx!=n-1||qy!=b[n-1]){
int sz=now.size();
per(i,sz-1,1){
if(i==1&&px==0&py==0)break;
now[i]-=now[i-1];
}
if(py==qy){
// 上に伸ばす
vector<int> _a(qx-px+1);
rep(i,0,qx-px+1)_a[i]=b[px+i]-py;
now=sub(qx-px+1,_a,now);
px=qx,py=qy;
qy=b[qx];
}
else{
// 右に伸ばす
int index=upper_bound(all(a),qy)-a.begin();
// (px,py), (qx=px,qy) ->
// (qx,qy), (index,qy)
vector<int> _a(qy-py+1);
rep(i,0,qy-py+1){
int index2=upper_bound(all(a),py+i)-a.begin();
_a[i]=index2-px;
}
per(i,qy-py,0)_a[i]-=_a[0];
now=sub(qy-py+1,_a,now);
px=qx,py=qy;
qx=index-1;
}
}
return now[now.size()-1];
}
};
using ll = long long;
#define rep(i,start,end) for(ll i=start;i<(ll)(end);i++)
using mint = modint998244353;
mint naive(int n,vector<int> a,vector b){
int mx=3000;
mint dp[n][mx]={};
rep(i,a[0],b[0]+1)dp[0][i]=1;
rep(i,1,n){
rep(j,a[i],b[i]+1){
rep(k,0,mx){
if(j<k)break;
dp[i][j]+=dp[i-1][k];
}
}
}
mint res=0;
rep(i,0,mx)res+=dp[n-1][i];
return res;
}
int main(){
int n;cin>>n;
vector a(n),b(n);
rep(i,0,n)cin>>a[i];
rep(i,0,n)cin>>b[i];
number_of_increasing_sequences_between_two_sequences num(n,a,b);
//num.debug_a();
//num.debug_b();
//cout<<endl;
cout<<num.answer().val()<<endl;
cout<<naive(n,a,b).val()<<endl;
}
経路の長さは経路によらず であることを考えると全体の計算量は です。