bzoj 3879: SvT 后缀自动机+虚树+树形dp

题意

有一个长度为n的仅包含小写字母的字符串S,下标范围为[1,n].
现在有若干组询问,对于每一个询问,我们给出若干个后缀(以其在S中出现的起始位置来表示),求这些后缀两两之间的LCP(LongestCommonPrefix)的长度之和.一对后缀之间的LCP长度仅统计一遍.
有S<=5*10^5,且Σt<=3*10^6.

分析

我们把S反过来后建后缀自动机,就转换成了求两两前缀的lcs。显然两个前缀的lcs等于其对应位置在parents树上lca的mx。
那么我们可以把这几个节点的虚树建出来,然后树形dp一下即可。

代码

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;

typedef long long LL;

const int N=1000005;

int n,m,fa[N],val[N],size[N],ch[N][26],sz,cnt,last[N],ls,lt[N],dep[N],rmq[N*2][25],bin[25],dfn[N],tim,dfn_tim,arr[N],lg[N*2],tot,a[N*13],stack[N],num[N];
struct edge{int to,next;}e[N*5];
LL ans;
char str[N];

int read()
{
    int x=0,f=1;char ch=getchar();
    while (ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while (ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}

void ins(int id,int x)
{
    int p,q,np,nq;
    p=ls;num[id]=ls=np=++sz;val[np]=val[p]+1;
    for (;p&&!ch[p][x];p=fa[p]) ch[p][x]=np;
    if (!p) fa[np]=1;
    else
    {
        q=ch[p][x];
        if (val[q]==val[p]+1) fa[np]=q;
        else
        {
            nq=++sz;val[nq]=val[p]+1;
            memcpy(ch[nq],ch[q],sizeof(ch[q]));
            fa[nq]=fa[q];
            fa[q]=fa[np]=nq;
            for (;ch[p][x]==q;p=fa[p]) ch[p][x]=nq;
        }
    }
}

bool cmp(int x,int y)
{
    return dfn[x]<dfn[y];
}

void addedge(int u,int v)
{
    e[++cnt].to=v;e[cnt].next=last[u];last[u]=cnt;
}

void addedge1(int u,int v)
{
    e[++cnt].to=v;e[cnt].next=lt[u];lt[u]=cnt;
}

void dfs(int x)
{
    dfn_tim++;dfn[x]=dfn_tim;dep[x]=dep[fa[x]]+1;
    tim++;rmq[tim][0]=x;arr[x]=tim;
    for (int i=last[x];i;i=e[i].next)
    {
        if (e[i].to==fa[x]) continue;
        dfs(e[i].to);
        tim++;rmq[tim][0]=x;
    }
}

void get_rmq()
{
    for (int i=1;i<=tim;i++) lg[i]=log(i)/log(2);
    for (int j=1;j<=lg[tim];j++)
        for (int i=1;i+bin[j]-1<=tim;i++)
            rmq[i][j]=dep[rmq[i][j-1]]<dep[rmq[i+bin[j-1]][j-1]]?rmq[i][j-1]:rmq[i+bin[j-1]][j-1];
}

int get_lca(int x,int y)
{
    int l=min(arr[x],arr[y]),r=max(arr[x],arr[y]),len=lg[r-l+1];
    return dep[rmq[l][len]]<dep[rmq[r-bin[len]+1][len]]?rmq[l][len]:rmq[r-bin[len]+1][len];
}

void dp(int x)
{
    for (int i=lt[x];i;i=e[i].next)
    {
        dp(e[i].to);
        ans+=(LL)val[x]*size[x]*size[e[i].to];
        size[x]+=size[e[i].to];
        size[e[i].to]=0;lt[e[i].to]=0;
    }
}

LL solve()
{
    int top=1,tmp=cnt;
    stack[1]=1;
    for (int i=1;i<=tot;i++)
    {
        int x=a[i],lca=get_lca(x,stack[top]);size[x]=1;
        while (top>1&&dep[stack[top-1]]>dep[lca]) addedge1(stack[top-1],stack[top]),top--;
        if (dep[stack[top]]>dep[lca]) addedge1(lca,stack[top]),top--;
        if (lca!=stack[top]) stack[++top]=lca;
        stack[++top]=x;
    }
    while (top>1) addedge1(stack[top-1],stack[top]),top--;
    ans=0;
    dp(1);
    size[1]=0;lt[1]=0;cnt=tmp;
    return ans;
}

int main()
{
    bin[0]=1;
    for (int i=1;i<=20;i++) bin[i]=bin[i-1]*2;
    n=read();m=read();
    scanf("%s",str+1);
    for (int i=1;i*2<=n;i++) swap(str[i],str[n-i+1]);
    ls=sz=1;
    for (int i=1;i<=n;i++) ins(i,str[i]-'a');
    for (int i=2;i<=sz;i++) addedge(fa[i],i);
    dfs(1);
    get_rmq();
    while (m--)
    {
        tot=read();
        for (int i=1;i<=tot;i++) a[i]=num[n-read()+1];
        sort(a+1,a+tot+1,cmp);
        tot=unique(a+1,a+tot+1)-a-1;
        printf("%lld\n",solve());
    }
    return 0;
}

版权声明:本文为qq_33229466原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。