码迷,mamicode.com
首页 > 其他好文 > 详细

[北京省选集训2019]生成树计数

时间:2020-07-04 15:30:15      阅读:61      评论:0      收藏:0      [点我收藏+]

标签:http   const   表示   get   base   一次函数   减法   lang   cpp   

题目

点这里看题目。

分析

考察一下矩阵树定理的基本式子:

\[\sum_T \prod_{e\in T} w_e \]

\(v(T)\)\(T\)的权值,我们发现,\(v(T)\)应该是\(T\)中的边的“某种意义”下的

这意味着,我们只需要能够保证\(v(T)\)的贡献可分割,便可以定义一个存在基础四则运算的“类型”,满足该“类型”的积就相当于合并了边的贡献。利用该“类型”,我们就可以用矩阵树定理计算所有生成树的贡献。

例如,一般矩阵树定理,利用(带模)整数的运算即可;

再例如,求边权和的矩阵树定理,利用“一次函数”的“积”,我们求出了系数的“和”。

对于此题,我们考虑将“合并贡献”作为切入点,即考虑已知\(a^k,b^k\),如何求出\((a+b)^k\)

当然这样没法做,我们考虑用二项式定理展开:

\[(a+b)^k=\sum_{i=0}^k\binom{k}{i}a^ib^{k-i} \]

这样需要用到较低次幂,不过没有关系,我们同样可以维护一下较低次幂,合并方法类似。

仔细观察你会发现这是指数生成函数的卷积,不过在这道题里面没有用。

然后我们只需要用长为\(k+1\)的向量\(\vec{z}=\begin{bmatrix}a^0&a^1&a^2&\dots&a^k\end{bmatrix}\)表示贡献即可。

加法和减法:不再赘述。

乘法:“二项式卷积”。

求逆元:可以次数从低到高进行递推,不再赘述。

单次乘法\(O(k^2)\),总时间复杂度\(O(n^3k^2)\)

代码

#include <cstdio>
#include <iostream>

const int mod = 998244353;
const int MAXN = 35;

template<typename _T>
void read( _T &x )
{
	x = 0;char s = getchar();int f = 1;
	while( s > ‘9‘ || s < ‘0‘ ){if( s == ‘-‘ ) f = -1; s = getchar();}
	while( s >= ‘0‘ && s <= ‘9‘ ){x = ( x << 3 ) + ( x << 1 ) + ( s - ‘0‘ ), s = getchar();}
	x *= f;
}

template<typename _T>
void write( _T x )
{
	if( x < 0 ){ putchar( ‘-‘ ); x = ( ~ x ) + 1; }
	if( 9 < x ){ write( x / 10 ); }
	putchar( x % 10 + ‘0‘ );
}

int C[MAXN][MAXN], invC[MAXN][MAXN];
int K;

int inv( const int );
int qkpow( int, int );
void add( int&, const int );
void sub( int&, const int );

struct vec
{
	int num[MAXN] = {};
	
	//h(i)=??{j=0...n}f(j)*g(i-j)*C(i,j) 
	
	vec() {}
	vec( const int b ) { num[0] = b; }
	
	int& operator [] ( const int indx ) { return num[indx]; }
	
	vec inver() const
	{
		vec ret;
		ret[0] = inv( num[0] );
		for( int i = 1, tmp ; i <= K ; i ++ )
		{
			tmp = 0;
			for( int j = 0 ; j < i ; j ++ )
				add( tmp, 1ll * ret[j] * num[i - j] % mod * C[i][j] % mod );
			ret[i] = 1ll * ( mod - tmp ) % mod * ret[0] % mod;
		}
		return ret;
	}
	
	vec operator + ( vec g ) const
	{
		vec ret;
		for( int i = 0 ; i <= K ; i ++ )
			add( ret[i], num[i] ), add( ret[i], g[i] );
		return ret;
	}
	
	vec operator - ( vec g ) const 
	{
		vec ret;
		for( int i = 0 ; i <= K ; i ++ )
			add( ret[i], num[i] ), sub( ret[i], g[i] );
		return ret;
	}

	vec operator * ( vec g ) const 
	{
		vec ret;
		for( int i = 0 ; i <= K ; i ++ )
			for( int j = 0 ; j <= i ; j ++ )
				add( ret[i], 1ll * num[j] * g[i - j] % mod * C[i][j] % mod );
		return ret;
	}
	
	vec operator / ( vec g ) const { return g * inver(); }
	void operator += ( vec g ) { *this = *this + g; }
	void operator -= ( vec g ) { *this = *this - g; }
	void operator *= ( vec g ) { *this = *this * g; }
	void operator /= ( vec g ) { *this = *this / g; }

	operator bool() const 
	{
		bool ret = false;
		for( int i = 0 ; i <= K ; i ++ )
			ret |= num[i]; 
		return ret; 
	} 
};

vec D[MAXN][MAXN], G[MAXN][MAXN], Kich[MAXN][MAXN];
int N;

int qkpow( int base, int indx )
{
	int ret = 1;
	while( indx )
	{
		if( indx & 1 ) ret = 1ll * ret * base % mod;
		base = 1ll * base * base % mod, indx >>= 1;
	}
	return ret;
}

int inv( const int a ) { return qkpow( a, mod - 2 ); }
void sub( int &x, const int v ) { x = ( x < v ? x - v + mod : x - v ); }
void add( int &x, const int v ) { x = ( x + v >= mod ? x + v - mod : x + v ); }

int det( vec T[][MAXN], const int n )
{
	vec ans = 1, tmp, inver; int indx;
	for( int i = 1 ; i <= n ; i ++ )
	{
		indx = -1;
		for( int j = i ; j <= n ; j ++ )
			if( T[j][i] ) 
			{ indx = j; break; }
		if( indx == -1 ) return 0;
		if( indx ^ i ) 
			for( int k = 0 ; k <= K ; k ++ )
				ans[k] = mod - ans[k];
		std :: swap( T[i], T[indx] );
		inver = T[i][i].inver();
		for( int j = i + 1 ; j <= n ; j ++ )
			if( T[j][i] )
			{
				tmp = T[j][i] * inver;
				for( int k = i ; k <= n ; k ++ )	
					T[j][k] -= T[i][k] * tmp;
			}
		ans *= T[i][i];
	}
	return ans[K];
}

int main()
{
	int w;
	read( N ), read( K );
	for( int i = 0 ; i <= K ; i ++ )
	{
		C[i][0] = C[i][i] = 1;
		for( int j = 1 ; j < i ; j ++ )
			add( C[i][j], C[i - 1][j] ), add( C[i][j], C[i - 1][j - 1] );
	}
	int t;
	for( int i = 1 ; i <= N ; i ++ )
		for( int j = 1 ; j <= N ; j ++ )
		{
			read( w ), t = 1;
			for( int k = 0 ; k <= K ; k ++ )
				G[i][j][k] = t, t = 1ll * t * w % mod;
		}
	for( int i = 1 ; i <= N ; i ++ )
		for( int j = 1 ; j <= N ; j ++ )
			D[i][i] += G[i][j];
	for( int i = 1 ; i <= N ; i ++ )
		for( int j = 1 ; j <= N ; j ++ )
			Kich[i][j] = D[i][j] - G[i][j];
	write( det( Kich, N - 1 ) ), putchar( ‘\n‘ );
	return 0;
}

[北京省选集训2019]生成树计数

标签:http   const   表示   get   base   一次函数   减法   lang   cpp   

原文地址:https://www.cnblogs.com/crashed/p/13234746.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!