题目大意
有一个\(n\)个点\(m\)条边的图,每条边有一种颜色\(c_i\in\{1,2,3\}\),求所有的包括\(i\)条颜色为\(1\)的边,\(j\)条颜色为\(2\)的边,\(k\)条颜色为\(3\)的边的生成树的数量。
对\({10}^9+7\)取模。
\(n\leq 50\)
题解
如果\(\forall i,c_i=1\),就可以直接用基尔霍夫矩阵计算生成树个数。但是现在有三种颜色,不妨设\(c_i=2\)的边的边权为\(x\),\(c_i=3\)的边的边权为\(y\)。因为\(x\)的次数不会超过\(n-1\),所以可以设\(y=x^n\)。基尔霍夫矩阵可能是这样子的:
\[\begin{bmatrix}1+x&-1&-x\\-1&1+x^n&-x^n\\-x&-x^n&x^{n+1}\end{bmatrix}\]
这样的话直接高斯消元很明显会TLE,可以用FFT优化。FFT是在每次乘法时先做一次求值,做一次点值乘法,最后做一次插值。所以我们可以先在消元前做一次求值,消元时直接做点值乘法,最后再插值插回来。
因为答案的最高次不超过\(n(n-1)\),所以我们可以把\(n(n-1)+1\)个点带到行列式中,把每次得到的行列式保存下来,最后再用拉格朗日插值插回来。这里不能用高斯消元是因为高斯消元会直接TLE。
求行列式的总时间复杂度是\(O(n^2)\times O(n^3)=O(n^5)\),拉格朗日插值的时间复杂度是\(O({(n^2)}^2)=O(n^4)\),高斯消元的时间复杂度是\(O({(n^2)}^3)=O(n^6)\)。
代码
#include#include #include #include #include #include using namespace std;typedef long long ll;typedef pair pii;ll p=1000000007;ll a[60][60];int lx[10010];int ly[10010];int lc[10010];ll x[3010];ll y[3010];ll c[3010];ll b[3010];ll d[3010];ll l[3010];ll ans[3010];ll fp(ll a,ll b){ ll s=1; while(b) { if(b&1) s=s*a%p; a=a*a%p; b>>=1; } return s;}ll calc(int t){ int i,j,k; ll s=1; for(i=1;i<=t;i++) { for(j=i;j<=t;j++) if(a[j][i]) break; if(j>t) return 0; int x=j; if(x>i) { s=-s; for(j=i;j<=t;j++) swap(a[i][j],a[x][j]); } for(j=i+1;j<=t;j++) if(a[j][i]) { ll d=a[j][i]*fp(a[i][i],p-2)%p; for(k=i;k<=t;k++) a[j][k]=(a[j][k]-a[i][k]*d%p)%p; } } for(i=1;i<=t;i++) s=s*a[i][i]%p; s=(s+p)%p; return s;}int main(){// freopen("count.in","r",stdin); int n,m; scanf("%d%d",&n,&m); int i,j,k; for(i=1;i<=m;i++) scanf("%d%d%d",&lx[i],&ly[i],&lc[i]); for(i=1;i<=n*n;i++) { x[i]=(i*1000+1)%p;// x[i]=i%p; ll px=fp(x[i],n); memset(a,0,sizeof a); for(j=1;j<=m;j++) { if(lc[j]==1) { a[lx[j]][lx[j]]++; a[ly[j]][ly[j]]++; a[lx[j]][ly[j]]--; a[ly[j]][lx[j]]--; } else if(lc[j]==2) { (a[lx[j]][lx[j]]+=x[i])%=p; (a[ly[j]][ly[j]]+=x[i])%=p; (a[lx[j]][ly[j]]-=x[i])%=p; (a[ly[j]][lx[j]]-=x[i])%=p; } else { (a[lx[j]][lx[j]]+=px)%=p; (a[ly[j]][ly[j]]+=px)%=p; (a[lx[j]][ly[j]]-=px)%=p; (a[ly[j]][lx[j]]-=px)%=p; } } y[i]=calc(n-1); } int t=n*n; memset(c,0,sizeof c); c[0]=1; memset(ans,0,sizeof ans); for(i=1;i<=t;i++) for(j=t;j>=0;j--) { (c[j+1]+=c[j])%=p; (c[j]=-c[j]*x[i])%=p; } for(i=1;i<=t;i++) { memcpy(d,c,sizeof d); memset(b,0,sizeof b); for(j=t;j>=0;j--) { b[j]=d[j+1]; d[j]=(d[j]+d[j+1]*x[i])%p; d[j+1]=0; } ll s=0,px=1; for(j=0;j<=t;j++) { s=(s+px*b[j])%p; px=px*x[i]%p; } s=fp(s,p-2)*y[i]%p; for(j=0;j<=t;j++) b[j]=b[j]*s%p; for(j=0;j<=t;j++) ans[j]=(ans[j]+b[j])%p; } for(i=0;i