题意
给定一个字符串 \(s\)。你可以选择 \([l,r]\) 满足 \(1 \le l \le r \le |s|\),然后将 \(s_{l\sim r}\) 删掉,把两边的字符串拼起来,设得到的这个新字符串为 \(s'\)。
接下来,你可以选择一对新的 \(1 \le l' \le r' \le |s'|\),记 \(s'_{l'\sim r'}\) 组成的字符串为 \(s''\)。
给定另一个字符串 \(t\),问有多少种方案,可以使得到的 \(s''=t\)。两种方案不同当且仅当 \(l,r,l',r'\) 有一个不同。
\(1 \le |s| \le 4 \times 10^5\),\(1 \le |t| \le 2 \times 10^5\)
题解
我们记 \(n=|s|\),\(m=|t|\);记第一步删去的子串为 \([l_1,r_1]\),第二步保留的子串映射到原串的下标为 \([l_2,r_2]\)。
注意到从 \(l_2\) 开始,向右匹配了 \(t\) 的一个前缀;从 \(r_2\) 开始,向左匹配了 \(t\) 的一个后缀;并且,当我们固定了 \(l_2\) 和 \(r_2\) 之后,\(l_1\) 和 \(r_1\) 之间的距离将保持一个定值。因此,我们希望求出固定 \(l_2\) 和 \(r_2\) 之后,有多少种 \(l_1\)(\(l_2\)) 的取值方案。
为了方便叙述,我们记 \(f_i\) 表示 \(s_{i\sim n}\) 和 \(t\) 的最长公共前缀,\(g_i\) 表示 \(s_{1\sim i}\) 和 \(t\) 的最长公共后缀。不难注意到,\(l_1\) 的取值方案数和 \(f[l_2]+g[r_2]\) 有关,为 \(f[l_2]+g[r_2]-m+1\)。
然而,我们要求 \(r_2\ge l_2+m\),否则 \([l_1,r_1]\) 为空或不存在;并且 \(f[l_2]+g[r_2]<m\) 时,不能拼成一个完整的 \(t\),因此不能产生贡献。这是一个二维数点问题,我们从右往左扫描数组,用 BIT
维护 \(g[r]\) 即可。
有一些细节问题:我们注意到,\([l_1,r_1]\) 和 \([l_2,r_2]\) 无交时,\(s[l_2\sim r_2]=t\),有一些情况会被算漏。因此我们先求出 \(t\) 在 \(s\) 中的所有匹配位置,然后对每个匹配位置都计算 \([l_1,r_1]\) 的取值方案即可。
然而,对于上面这种情况,当 \(r_1=r_2\) 时,其贡献还会被二维数点部分计算一遍。因此我们执行 \(f_i\leftarrow \min(f_i,m-1)\),\(g_i\leftarrow \min(g_i,m-1)\),以去除 \(l_2\) 前缀或 \(r_2\) 后缀匹配为空的情况。
至于如何求出 \(f_i\) 和 \(g_i\),我们可以使用 exKMP
、SA
或者二分哈希。这里我们使用二分哈希。
AC 代码
| #include<iostream>
#include<cstring>
#include<cassert>
#define ll long long
#define ull unsigned long long
using namespace std;
const int N1 = 4e5 + 10;
const int N2 = 2e5 + 10;
const ull BASE = 131;
int n, m;
ll ans;
string s, t;
int f[N1], g[N1];
namespace Hash {
ull pw[N1];
ull sum1[N1], sum2[N2];
ull hsh1(int l, int r) {
return sum1[r] - sum1[l - 1] * pw[r - l + 1];
}
ull hsh2(int l, int r) {
return sum2[r] - sum2[l - 1] * pw[r - l + 1];
}
void init() {
sum1[0] = sum2[0] = pw[0] = 1;
for(int i = 1; i <= n; i++) pw[i] = pw[i - 1] * BASE;
for(int i = 1; i <= n; i++) sum1[i] = sum1[i - 1] * BASE + s[i];
for(int i = 1; i <= m; i++) sum2[i] = sum2[i - 1] * BASE + t[i];
}
}
using Hash::hsh1;
using Hash::hsh2;
namespace BIT {
ll sum[2][N2];
inline int lowbit(int x) { return x & -x; }
inline void add(int id, int p, ll v) {
assert(p > 0 && p <= m);
for(int i = m - p + 1; i <= m; i += lowbit(i)) sum[id][i] += v;
}
inline ll query(int id, int p) {
assert(p > 0 && p <= m);
ll res = 0;
for(int i = m - p + 1; i > 0; i -= lowbit(i)) res += sum[id][i];
return res;
}
};
int main() {
cin >> s >> t;
n = s.size();
m = t.size();
s = '#' + s;
t = '#' + t;
Hash::init();
for(int i = 1; i <= n - m + 1; i++) {
if(hsh1(i, i + m - 1) == hsh2(1, m)) {
if(i > 1) ans += (ll)i * (i - 1) / 2;
if(i + m - 1 < n) ans += (ll)(n - (i + m - 1) + 1) * (n - (i + m - 1)) / 2;
}
}
for(int i = 1; i <= n - m + 1; i++) {
int l = 0, r = m;
while(l < r) {
int mid = (l + r + 1) >> 1;
if(hsh1(i, i + mid - 1) == hsh2(1, mid)) {
l = mid;
} else r = mid - 1;
}
f[i] = l;
}
for(int i = m; i <= n; i++) {
int l = 0, r = m;
while(l < r) {
int mid = (l + r + 1) >> 1;
if(hsh1(i - mid + 1, i) == hsh2(m - mid + 1, m)) {
l = mid;
} else r = mid - 1;
}
g[i] = l;
}
for(ll i = 1; i <= n; i++) f[i] = min(f[i], m - 1);
for(ll i = 1; i <= n; i++) g[i] = min(g[i], m - 1);
for(int i = n - m; i >= 1; i--) {
if(g[i + m]) {
BIT::add(0, g[i + m], 1);
BIT::add(1, g[i + m], g[i + m]);
}
if(f[i]) {
ans += BIT::query(0, m - f[i]) * (f[i] - m + 1);
ans += BIT::query(1, m - f[i]);
}
}
cout << ans << endl;
return 0;
}
|