数位 DP 学习笔记
概念
Q:什么是数位 DP?
A:就是一种对数进行的 DP。
Q:常见题型?
A:找出区间 \([L, R]\) 中满足某种条件 \(f(i)\) 的数的个数,限制 \(f(i)\) 只与 \(i\) 的每一位上的数字有关,而与 \(i\) 的值无关。
Q:如何求解?
A:一般都有套路,具体下面会讲。
常见思路
数位 DP 一般使用记忆化搜索实现。
根据某种套路,我们将询问区间 \([L, R]\) 拆成 \([0,R]\) 和 \([0,L-1]\),输出答案的时候直接前缀和一下即可。
一般设 \(dp_{cur,lst,\dots,f,g}\) 表示当前填到了第 \(cur\) 位,上一位的状态是 \(lst\)(有时候不止要记上一位的,可能还要记前面很多位的,因此打了省略号),\(f\) 表示前 \(cur\) 位是否与上界相同,\(g\) 则表示有无前导零。
由于数位 DP 基本上都是套个模板,因此这里就直接给出模板。
int l, r; //答案区间
int dp[N][N][...][2][2]; //DP 数组
int tot, b[N]; //将数按位拆开
int dfs(int cur/*当前填到了哪一位*/, int lst, .../*上一位的状态*/, bool f/*前 cur 位是否与上界相同*/, bool g/*有无前导零*/) //数位 DP 函数
{
if (cur == tot + 1) return 1; //边界
if (dp[cur][lst][...][f][g] != -1) return dp[cur][lst][...][f][g]; //记忆化搜索,已经搜过就直接返回
int v = 9;
if (f) v = b[cur]; //求出当前枚举的位上的数的上界
int ans = 0;
for (int i = 0; i <= v; i+=1) //枚举这一位填什么数
{
if (g == true) //有前导零
{
if (i == 0) ans += dfs(cur + 1, -1, ..., f && (i == v), true); //这一位不填数
else ans += dfs(cur + 1, i, ..., f && (i == v), false); //这一位填上 i
}
else /*这里根据题目不同可能要加上各种限制条件*/ ans += dfs(cur + 1, i, ..., f && (i == v), false); //累加这一位填上 i 的答案
}
return dp[cur][lst][...][f][g] = ans; //记忆化
}
inline int solve(int x)
{
memset(dp, -1, sizeof dp); //DP 数组初始化
tot = 0;
while (x)
{
b[++tot] = x % 10;
x /= 10;
}
reverse(b + 1, b + 1 + tot); //将上界按位拆开
return dfs(1, -1, true, true); //进行一次 DP
}
int main()
{
l = gi (), r = gi ();
printf("%d\n", solve(r) - solve(l - 1)); //前缀和计算答案
return 0;
}
具体的实践还是需要依题目而定,因此我们来看几道例题。
例题
T1. [SCOI2009] windy 数
可以说是数位 DP 的一道入门题。
根据模板,我们只需要记录上一个位置填的数是什么,然后转移的时候判断当前填的数是否满足条件即可。
#include
#define DEBUG fprintf(stderr, "Passing [%s] line %d\n", __FUNCTION__, __LINE__)
#define File(x) freopen(x".in","r",stdin); freopen(x".out","w",stdout)
using namespace std;
typedef long long LL;
typedef pair PII;
typedef pair PIII;
template
inline T gi()
{
T f = 1, x = 0; char c = getchar();
while (c < '0' || c > '9') {if (c == '-') f = -1; c = getchar();}
while (c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
return f * x;
}
const int INF = 0x3f3f3f3f, N = 13, M = N << 1;
int l, r, tot;
int dp[N][N][2][2];
int b[N];
int dfs(int cur, int lst, bool f, bool g)
//填到第 cur 位,上一位的数是 lst,f 表示是否与上界相等,g 表示是否有前导零
{
if (cur == tot + 1) return 1;
if (dp[cur][lst][f][g] != -1) return dp[cur][lst][f][g];
int v = 9;
if (f) v = b[cur];
int ans = 0;
for (int i = 0; i <= v; i+=1)
{
if (g == true)
{
if (i == 0) ans += dfs(cur + 1, -1, f && (i == v), true);
else ans += dfs(cur + 1, i, f && (i == v), false);
}
else if (abs(i - lst) >= 2) /*判断能否转移*/ ans += dfs(cur + 1, i, f && (i == v), false);
}
return dp[cur][lst][f][g] = ans;
}
inline int solve(int x)
{
memset(dp, -1, sizeof dp);
tot = 0;
while (x)
{
b[++tot] = x % 10;
x /= 10;
}
reverse(b + 1, b + 1 + tot);
return dfs(1, -1, true, true);
}
int main()
{
l = gi (), r = gi ();
printf("%d\n", solve(r) - solve(l - 1));
return 0;
}
T2. [ZJOI2010]数字计数
对每一个数码进行一次数位 DP,需要记录一下填到第 cur 位当前要统计的数的出现次数,别的都没什么太大差别。
#include
#define DEBUG fprintf(stderr, "Passing [%s] line %d\n", __FUNCTION__, __LINE__)
#define File(x) freopen(x".in","r",stdin); freopen(x".out","w",stdout)
using namespace std;
typedef long long LL;
typedef pair PII;
typedef pair PIII;
template
inline T gi()
{
T f = 1, x = 0; char c = getchar();
while (c < '0' || c > '9') {if (c == '-') f = -1; c = getchar();}
while (c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
return f * x;
}
const int INF = 0x3f3f3f3f, N = 15, M = N << 1;
LL l, r;
int tot, b[N];
LL dp[N][N][2][2];
LL dfs(int now, int cur, int lst, bool f, bool g)
{
if (cur == tot + 1) return lst;
if (dp[cur][lst][f][g] != -1) return dp[cur][lst][f][g];
int v = 9;
if (f) v = b[cur];
LL ans = 0;
for (int i = 0; i <= v; i+=1)
{
if (g)
{
if (i == 0) ans += dfs(now, cur + 1, 0, f && (i == v), true);
else ans += dfs(now, cur + 1, lst + (i == now), f && (i == v), false);
}
else ans += dfs(now, cur + 1, lst + (i == now), f && (i == v), false);
}
return dp[cur][lst][f][g] = ans;
}
inline LL solve(int now, LL x)
{
tot = 0;
while (x)
{
b[++tot] = x % 10;
x /= 10;
}
reverse(b + 1, b + 1 + tot);
memset(dp, -1, sizeof dp);
return dfs(now, 1, 0, true, true);
}
int main()
{
//File("");
l = gi (), r = gi ();
for (int i = 0; i <= 9; i+=1)
printf("%lld ", solve(i, r) - solve(i, l - 1));
return 0;
}
T3. [CQOI2016]手机号码
这题要记录的状态可能多一些……
对于当前状态,我们需要记录 \(p1\)、\(p2\)(前两个数是什么)、\(ok\)(是否已经出现至少 \(3\) 个相邻的相同数字)、\(h8\)、\(h4\)(分别为是否已经出现了 \(8\) 或 \(4\)),因此转移的时候有一些细节需要注意。
#include
#define DEBUG fprintf(stderr, "Passing [%s] line %d\n", __FUNCTION__, __LINE__)
#define File(x) freopen(x".in","r",stdin); freopen(x".out","w",stdout)
using namespace std;
typedef long long LL;
typedef pair PII;
typedef pair PIII;
template
inline T gi()
{
T f = 1, x = 0; char c = getchar();
while (c < '0' || c > '9') {if (c == '-') f = -1; c = getchar();}
while (c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
return f * x;
}
const int INF = 0x3f3f3f3f, N = 15, M = N << 1;
LL l, r;
int tot, b[N];
LL dp[N][N][N][2][2][2][2][2];
LL dfs(int cur, int p1, int p2, bool ok, bool h8, bool h4, bool f, bool g)
{
if (cur == tot + 1) return ok == true;
if (dp[cur][p1][p2][ok][h8][h4][f][g] != -1) return dp[cur][p1][p2][ok][h8][h4][f][g];
int v = 9;
if (f) v = b[cur];
LL ans = 0;
for (int i = 0; i <= v; i+=1)
{
if (g)
{
if (i == 0) ans += dfs(cur + 1, p1, p2, ok, h8, h4, f && (i == v), g);
else ans += dfs(cur + 1, i, p1, ok, (i == 8), (i == 4), f && (i == v), false);
}
else
{
if (h8 && (i == 4)) continue;
if (h4 && (i == 8)) continue; //同时出现了 8 和 4
if (p1 == p2 && i == p1)
ans += dfs(cur + 1, i, p1, true, h8 || (i == 8), h4 || (i == 4), f && (i == v), false); //已经出现了 3 个相邻的相同数字
else
ans += dfs(cur + 1, i, p1, ok, h8 || (i == 8), h4 || (i == 4), f && (i == v), false);
}
}
return dp[cur][p1][p2][ok][h8][h4][f][g] = ans;
}
inline LL solve(LL x)
{
tot = 0;
while (x)
{
b[++tot] = x % 10;
x /= 10;
}
reverse(b + 1, b + 1 + tot);
memset(dp, -1, sizeof dp);
return dfs(1, -1, -1, false, false, false, true, true);
}
int main()
{
l = gi (), r = gi ();
printf("%lld\n", solve(r) - solve(l - 1));
return 0;
}
T4. AcWing310 启示录
与上一道题目差不多的套路,记录一下前两位是否为 6 以及是否已经出现了连续至少 \(3\) 个 6。
注意还需要进行一个二分,找到第一个 \(ans\) 使得 \(ans\) 是满足 \(\le ans\) 的魔鬼数个数为 \(k\) 的数中最小的一个。
#include
#define DEBUG fprintf(stderr, "Passing [%s] line %d\n", __FUNCTION__, __LINE__)
#define File(x) freopen(x".in","r",stdin); freopen(x".out","w",stdout)
using namespace std;
typedef long long LL;
typedef pair PII;
typedef pair PIII;
template
inline T gi()
{
T f = 1, x = 0; char c = getchar();
while (c < '0' || c > '9') {if (c == '-') f = -1; c = getchar();}
while (c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
return f * x;
}
const int INF = 0x3f3f3f3f, N = 20, M = N << 1;
int T, x;
int tot, b[N];
LL dp[N][2][2][2][2];
LL dfs(int cur, bool p1, bool p2, bool ok, bool f)
{
if (cur == tot + 1) return (ok == true);
if (dp[cur][p1][p2][ok][f] != -1) return dp[cur][p1][p2][ok][f];
int v = 9;
if (f) v = b[cur];
LL ans = 0;
for (int i = 0; i <= v; i+=1)
{
ans += dfs(cur + 1, i == 6, p1, ok || (p1 && p2 && (i == 6)), f && (i == v));
}
return dp[cur][p1][p2][ok][f] = ans;
}
inline LL solve(LL mid)
{
tot = 0;
while (mid)
{
b[++tot] = mid % 10;
mid /= 10;
}
reverse(b + 1, b + 1 + tot);
memset(dp, -1, sizeof dp);
return dfs(1, false, false, false, true);
}
int main()
{
T = gi ();
while (T--)
{
x = gi ();
LL l = 0, r = 1000000000000000ll, ans = 0;
while (l <= r)
{
LL mid = (l + r) >> 1;
if (solve(mid) >= x) ans = mid, r = mid - 1;
else l = mid + 1;
}
printf("%lld\n", ans);
}
return 0;
}
T5. AcWing311 月之谜(思考题)
此题中我们发现“整除”这个操作并不好直接处理,有什么方式可以完成呢?
提示:考虑到单纯的数位 DP 复杂度并不高,于是可以枚举一下这个数的各位数字之和再进行数位 DP。