做题记录 3


Digit Tree

来源:CF715C, 2800

路径问题考虑用点分治进行处理.  

考虑当前的分治中心 $\mathrm{x}$, 对于两条路径 $\mathrm{d1,d2}$ 如何合并.  

令 $\mathrm{d1,d2}$ 分别表示从下到上,从上到下的路径.   

然后$\mathrm{d1}$ 和 $\mathrm{d2}$ 能合并的条件是 $\mathrm{d1=-d2 \times 10^{-dep2}}$.   

有这个式子后开两个 $\mathrm{map}$ 分别进行统计即可.  

这里注意这个 $\mathrm{m}$ 不一定是质数,所以需要用扩展欧几里得算法求逆元.  

#include 
#include  
#include 
#include 
#include 
#define N  100009 
#define ll long long 
#define setIO(s) freopen(s".in","r",stdin) 
using namespace std; 
ll ans; 
int n, mod, inv, hd[N], to[N << 1], nex[N << 1], val[N << 1], edges, sn, root;  
int f[N], size[N], vis[N], dep[N], bu[N], base[N];    
mapc1, c2;    
ll exgcd(ll a, ll b, ll &x, ll &y) {
    if(!b) {
        x = 1, y = 0; 
        return a; 
    }
    ll gc = exgcd(b, a % b, x, y);  
    ll t = x;  
    x = y, y = t - (a / b) * y;  
    return gc; 
} 
void init() {
    ll x0, y0; 
    exgcd(10, mod, x0, y0);  
    x0 = (x0 % mod + mod) % mod;  
    inv = (int)x0;  
}
void add(int u, int v, int c) {
    nex[++edges] = hd[u]; 
    hd[u] = edges, to[edges] = v, val[edges] = c;  
}
void getroot(int x, int ff) {
    size[x] = 1, f[x] = 0; 
    for(int i = hd[x]; i ; i = nex[i]) {
        int v = to[i]; 
        if(v == ff || vis[v]) continue;  
        getroot(v, x), size[x] += size[v];  
        f[x] = max(f[x], size[v]);  
    }
    f[x] = max(f[x], sn - f[x]);    
    if(f[x] < f[root]) root = x;   
}
// 第一波.  
void calc(int x, int ff, int d1, int d2) {
    dep[x] = dep[ff] + 1;
    // 考虑 (d1, d2) 对前面的贡献.  
    ans += c2[d1];  
    ans += c1[(ll)(mod - d2) * bu[dep[x]] % mod];     
    for(int i = hd[x]; i ; i = nex[i]) {
        int v = to[i]; 
        if(v == ff || vis[v]) {
            continue;  
        }
        calc(v, x, (ll)(d1 + (ll)base[dep[x]] * val[i] % mod) % mod, (ll)((ll)d2 * 10 % mod + val[i]) % mod);        
    }   
}
// 第二波. 
void getdis(int x, int ff, int d1, int d2) {    
    dep[x] = dep[ff] + 1;        
    if(!d1) ++ ans; 
    if(!d2) ++ ans; 
    c1[d1] ++ ;  
    c2[(ll)(mod - d2) * bu[dep[x]] % mod] ++ ;   
    // c1 = d1  
    // c2 = -d2 * 10 ^ (-dep)     
    for(int i = hd[x]; i ; i = nex[i]) {
        int v = to[i]; 
        if(v == ff || vis[v]) {
            continue;  
        }
        getdis(v, x, (ll)(d1 + (ll)base[dep[x]] * val[i] % mod) % mod, (ll)((ll)d2 * 10 % mod + val[i]) % mod);    
    }   
}
void dfs(int x) {
    vis[x] = 1;  
    for(int i = hd[x]; i ; i = nex[i]) {
        int v = to[i]; 
        if(vis[v]) continue;   
        dep[x] = 0; 
        calc(v, x, val[i] % mod, val[i] % mod); 
        getdis(v, x, val[i] % mod, val[i] % mod);  
    }  
    // 当前计算完毕.  
    c1.clear(); 
    c2.clear(); 
    for(int i = hd[x]; i ; i = nex[i]) {
        int v = to[i]; 
        if(vis[v]) continue;  
        root = 0, sn = size[v], getroot(v, 0);  
        dfs(root);  
    }
}
int main() {
    // setIO("input");  
    scanf("%d%d", &n, &mod);  
    init(); 
    for(int i = 1; i < n ; ++ i) {
        int x, y, z; 
        scanf("%d%d%d", &x, &y, &z);  
        ++ x, ++ y;  
        add(x, y, z), add(y, x, z); 
    }
    bu[0] = 1, base[0] = 1; 
    for(int i = 1; i < N ; ++ i) {
        bu[i] = (ll)bu[i - 1] * inv % mod;  
        base[i] = (ll)base[i - 1] * 10 % mod;  
    }
    f[root = 0] = N, sn = n, getroot(1, 0);      
    dfs(root); 
    printf("%lld", ans); 
    return 0; 
}

相关