标签:lld pac print turn 问题 question void rip bit
有 \(n\) 中烹饪方法和 \(m\) 种食材,要求:
求做菜的方案数。
先容斥一下,答案为忽略第三个条件所得的方案数减去每一种食材超过一半的方案数之和。
忽略掉第三个条件之后答案显然是
\[ \prod_{i=1}^n(1+\sum_{j=1}^m a_{i,j})-1 \]
减去 1 是去掉一道菜都不做的方案。
枚举每一列超过一半的情况,显然,除这一列外,其他 \(n-1\) 列是一样的。那么对于第 \(col\) 列,设 \(f_{i,j,k}\) 表示前 \(i\) 行,第 \(col\) 列选 \(j\) 个且其他列选 \(k\) 个的方案数。则:
\[ f_{i,j,k} = f_{i-1,j,k}\text{(不选)}+a_{i,col}*f_{i-1,j-1,k}+(s_i-a_{i,col})*f_{i,-1,j,k-1} \]
此时的复杂度是 ,\(O(m)\) 的枚举 \(col\) * \(O(n^3)\) 的 \(DP\), = \(O(mn^3)\) ,可以得到 84pts
的好成绩了
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 101;
const int M = 2001;
const int mod = 998244353;
ll n,m;
ll s[N],a[N][M],f[N][N][N];
ll ans=1;
void init()
{
scanf("%lld%lld",&n,&m);
for(int i=1;i<=n;++i)
{
for(int j=1;j<=m;++j)
{
scanf("%lld",&a[i][j]);
s[i]=(s[i]+a[i][j])%mod;
}
ans=(ans*(s[i]+1))%mod;
}
ans=(mod-1+ans)%mod;
}
int main()
{
init();
for(int col=1;col<=m;++col)
{
memset(f,0,sizeof(f));
f[0][0][0]=1;
for(int i=1;i<=n;++i)
{
for(int j=0;j<=i;++j)
{
for(int k=0;k<=i-j;++k)
{
f[i][j][k]=f[i-1][j][k]+f[i-1][j-1][k]*a[i][col]+f[i-1][j][k-1]*(s[i]-a[i][col]);
f[i][j][k]=(f[i][j][k]%mod+mod)%mod;
}
}
}
for(int j=1;j<=n;++j)
{
for(int k=0;k<=n-j;++k)
{
if(k<j) ans=((ans-f[n][j][k])%mod+mod)%mod;
}
}
}
printf("%lld\n",ans);
return 0;
}
然后我们发现我们并不关心j和k的具体值。我们只关心他们的差。所以我们可以把后两维压缩成一维。
设 \(f_{i,j}\) 表示前 \(i\) 行,第 \(col\) 列比其他列多选 \(j\) 个的方案数。则:
\[ f_{i,j} = f_{i-1,j}\text{(不选)}+a_{i,col}*f_{i-1,j-1}+(s_i-a_{i,col})*f_{i,-1,j+1} \]
此时的复杂度是 ,\(O(m)\) 的枚举 \(col\) * \(O(n^2)\) 的 \(DP\), = \(O(mn^2)\) ,可以得到 100pts
的好成绩了
这里有一个小技巧就是把每个j都加上n,避免数组负下标的出现。
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 101;
const int M = 2001;
const int mod = 998244353;
ll n,m;
ll s[N],a[N][M],f[N][N*2];
ll ans=1;
void init()
{
scanf("%lld%lld",&n,&m);
for(int i=1;i<=n;++i)
{
for(int j=1;j<=m;++j)
{
scanf("%lld",&a[i][j]);
s[i]=(s[i]+a[i][j])%mod;
}
ans=(ans*(s[i]+1))%mod;
}
ans=(mod-1+ans)%mod;
}
int main()
{
init();
for(int col=1;col<=m;++col)
{
memset(f,0,sizeof(f));
f[0][n]=1;
for(int i=1;i<=n;++i)
{
for(int j=n-i;j<=n+i;++j)//注意dp的范围!
{
f[i][j]=f[i-1][j]+f[i-1][j-1]*a[i][col]+f[i-1][j+1]*(s[i]-a[i][col]);
f[i][j]=(f[i][j]%mod+mod)%mod;
}
}
for(int j=1;j<=n;++j)
{
ans=((ans-f[n][n+j])%mod+mod)%mod;
}
}
printf("%lld\n",ans);
return 0;
}
DP的取值范围问题还是不清楚。
标签:lld pac print turn 问题 question void rip bit
原文地址:https://www.cnblogs.com/oierwyh/p/12267569.html