[CF1442C] Graph Transpositions - 分层图最短路
[CF1442C] Graph Transpositions - 分层图最短路
Description
给你一个\(n\)个顶点和\(m\)条边的有向图。顶点编号从\(1\)到\(n\)。顶点\(1\)处有一个标记。
你可以进行以下两种操作:
- 移动标记:如果存在一条\(u\to v\)的边,将标记从\(u\)移动到\(v\),这个操作需要\(1\)秒。
- 图翻转:翻转图上的所有边的方向,将图上每一条边\(u\to v\)替换为\(v\to u\),第\(k\)次使用这个操作需要耗时\(2^{k-1}\)秒。
你需要找到将标记从\(1\)移动到\(n\)的最短时间,请将答案对\(998,244,353\)取模。
Solution
朴素分层图有 \(O(n^2)\) 个点,需另辟蹊径
设 \(C=\lceil \log_2 n \rceil\),则当 \(k \ge C\) 时,层数最小不会使得答案更劣
对 \(k\le C\) 的部分建分层图跑最短路
层内连边 \(p \to q (cost=1)\),第 \(k\) 层到第 \(k+1\) 层 \(i_k \to i_{k+1} (cost = 2^k)\)
如果无解,说明 \(k>C\),此时我们优先最小化层数
每次做完层内转移,然后集体转移到下一层
若 \(k \bmod 2=0\) 则用正图,否则用反图
先在层内转移,内连边 \(p \to q (cost=1)\)
后再层间转移,只要 \(n\) 点还未到达,所有点就一起转移,从 \(k\) 层到 \(k+1\) 层,整体加上代价 \(2^k\)
这个代价额外存一个偏移量中即可,\(f[]\) 中记录的代价只是 \(k \le C\) 部分的代价和 \(k >C\) 部分的层内转移代价,这样 \(f[]\) 中的数字不用取模,可以比较大小
由于边权只有 \(0/1\),总体的时间复杂度还是 \(O(m \log n)\) 的
最短路过程中记录代价用二维数组 \(dis[][]\),后续过程用 \(f[]\)
#include
using namespace std;
#define int long long
const int mod = 998244353;
const int N = 400005;
const int M = 25;
int qpow(int p, int q)
{
return (q & 1 ? p : 1) * (q ? qpow(p * p % mod, q / 2) : 1) % mod;
}
int n, m;
int dis[M][N], f[N];
bool vis[M][N];
vector g[2][N];
void spfa()
{
queue> que;
dis[0][1] = 0;
que.push({0, 1});
vis[0][1] = 1;
while (que.size())
{
auto [k, p] = que.front();
que.pop();
for (int q : g[k & 1][p])
{
if (dis[k][q] > dis[k][p] + 1)
{
dis[k][q] = dis[k][p] + 1;
if (vis[k][q] == 0)
que.push({k, q}), vis[k][q] = 1;
}
}
if (k < 18)
if (dis[k + 1][p] > dis[k][p] + (1ll << k))
{
dis[k + 1][p] = dis[k][p] + (1ll << k);
if (vis[k + 1][p] == 0)
que.push({k + 1, p}), vis[k + 1][p] = 1;
}
}
}
signed main()
{
ios::sync_with_stdio(false);
cin >> n >> m;
for (int i = 1; i <= m; i++)
{
int t1, t2;
cin >> t1 >> t2;
g[0][t1].push_back(t2);
g[1][t2].push_back(t1);
}
memset(dis, 0x3f, sizeof dis);
memset(f, 0x3f, sizeof f);
spfa();
int ans = 1e18;
for (int i = 0; i < 18; i++)
ans = min(ans, dis[i][n]);
if (ans < 1e9)
{
cout << ans << endl;
}
else
{
f[1] = 0;
queue que, que_next;
que.push(1);
que_next.push(1);
int delta = 0;
for (int k = 0;; k++)
{
while (que.size())
{
int p = que.front();
que.pop();
for (int q : g[k & 1][p])
{
if (f[q] > f[p] + 1)
{
f[q] = f[p] + 1;
que.push(q);
que_next.push(q);
}
}
}
while (que_next.size())
{
que.push(que_next.front());
que_next.pop();
}
if (f[n] < 1e9)
{
cout << (f[n] + delta) % mod << endl;
return 0;
}
delta += qpow(2, k);
delta %= mod;
}
}
}