[Luogu4705] 玩游戏
Description
给定两个长度分别为 \(n\) 和 \(m\) 的序列 \(a\) 和 \(b\)。要从这两个序列中分别随机一个数,设为 \(a_x,b_y\),定义该次游戏的 \(k\) 次收益为 \((a_x+b_y)^k\) 。对于 \(i=1,2,\dots,t\),求一次游戏 \(i\) 次收益的期望。\(n,m,t\leq 10^5\)。
Sol
根据期望的线性性,显然可以求每个点对的 \(i\) 次收益,最后再除以 \(nm\) 就好了。
所以问题转化为,对于每个 \(k\),求:
\[\sum_{i=1}^n\sum_{j=1}^m (a_i+b_j)^k \]接下来直接推导:
\[\begin{aligned} ans_k&=\sum_{i=1}^n\sum_{j=1}^m (a_i+b_j)^k\\ &=\sum_{i=1}^n\sum_{j=1}^m\sum_{p=0}^k \binom kpa_i^pb_j^{k-p}\\ &=\sum_{p=0}^k\binom kp \left(\sum_{i=1}^na_i^p\right) \left(\sum_{j=1}^mb_j^{k-p} \right)\\ &=k!\cdot\sum_{p=0}^k \left(\sum_{i=1} ^n \frac{a_i^p}{p!}\right) \left(\sum_{j=1}^m\frac{b_j^{k-p}}{(k-p)!} \right) \end{aligned} \]发现这是个卷积式子,现在问题变成了如何求:
\[\sum_{i=1}^n a_i^p \]设 \(F(x)=\prod\limits_{i=1}^n(1+a_ix),G(x)=\ln(F(x))\)
那么:
\[\begin{aligned} G(x)&=\ln(\prod_{i=1}^n 1+a_ix)\\ &=\sum_{i=1}^n \ln(1+a_ix) \end{aligned} \]把 \(\ln(1+a_ix)\) 泰勒展开:
\[\begin{aligned} G(x)&=\sum_{i=1}^n \ln(1+a_ix)\\ &= \sum_{i=1}^n \sum_{k=1}^\infty \frac{(-1)^{k+1}}{k}\cdot a_i^k\cdot x^k\\ &= \sum_{k=1}^\infty \frac{(-1)^{k+1}}k\cdot x^k\cdot \left( \sum_{i=1}^n a_i^k \right) \end{aligned} \]后边那项就是我们要求的。
总结一下,先分治\(\text{NTT}\)求出\(F(x)\),再取对数求出\(G(x)\),然后第 \(k\) 项乘上一个系数就是 \(\sum\limits_{i=1}^n a_i^k\) 了。
Code
#pragma GCC optimize(2)
#include
using namespace std;
typedef double db;
typedef long long ll;
typedef vector vec;
const int N=262144+5;
const int mod=998244353;
#define pb push_back
int w[2][N],in[N];
int fac[N],ifac[N],A[N],B[N];
int n,m,t,a[N],b[N],c[N],d[N];
int lim,maxn,rev[N],tmpa[N],tmpb[N];
int ksm(int a,int b=mod-2,int ans=1){
while(b){
if(b&1) ans=1ll*ans*a%mod;
a=1ll*a*a%mod;b>>=1;
} return ans;
}
void ntt(int *f,int g){
for(int i=1;i=mod?x+y-mod:x+y,f[j+k+mid]=x-y<0?x-y+mod:x-y;
}
}
} if(g)
for(int i=0;i>1;
vec L=calc(a,l,mid),R=calc(a,mid+1,r);
lim=1;while(lim<=r-l+1) lim<<=1;
for(int i=1;i>1]>>1)|(i&1?lim>>1:0);
for(int i=0;i<(int)L.size();i++) A[i]=L[i];
for(int i=0;i<(int)R.size();i++) B[i]=R[i];
ntt(A,0),ntt(B,0);
for(int i=0;i>1); lim=len<<1;
for(int i=1;i>1]>>1)|(i&1?lim>>1:0);
for(int i=len;i>1]>>1)|(i&1?lim>>1:0);
ds(a,tmpb,n);
ntt(tmpa,0),ntt(tmpb,0);
for(int i=0;i>1]>>1)|(i&1?lim>>1:0);
ntt(a,0),ntt(b,0);
for(int i=0;i