题目链接:
题意:给出一棵树,每个节点有一个值w。若干询问,每个询问s,t,a,b,k,询问从s节点走到t节点权值在[a,b]之间的第k个节点。
思路:
(1)首先把从根出发到每个点的权值分布用函数式线段树全部记录,则通过函数式线段树的减法操作以及计算两个点LCA就可以得到任意两个点路径上权值a到b的点个数。
(2)记录每个点的2次幂祖先,可以计算两点LCA;
(3)分别计算s到LCA和LCA 到t中权值在[a,b]的点个数left和right,如果left+right小于k,则输出-1;如果left不少于k,则答案点肯定在s到LCA上,否则在LCA到t上。
#include#include #include #include using namespace std;struct node{ int a,b,L,R,s,mid;};const int MAX=100005;node p[MAX*20];int root[MAX],tot;int f[MAX][25],n,m,w[MAX],dep[MAX];vector g[MAX];int build(int a,int b){ int k=++tot; p[k].a=a; p[k].b=b; p[k].s=0; p[k].mid=(a+b)>>1; if(a==b) return k; p[k].L=build(a,p[k].mid); p[k].R=build(p[k].mid+1,b); return k;}int change(int c,int s){ int k=++tot; p[k]=p[c]; p[k].s++; if(p[k].a==p[k].b) return k; if(s<=p[k].mid) p[k].L=change(p[c].L,s); else p[k].R=change(p[c].R,s); return k;}void DFS(int u,int pre,int depth){ f[u][0]=pre; dep[u]=depth; root[u]=change(root[pre],w[u]); int i,v; for(i=0;i dep[b]) { temp=a; a=b; b=temp; } if(dep[a] =0;i--) { if(f[a][i]!=f[b][i]&&f[a][i]&&f[b][i]) { a=f[a][i]; b=f[b][i]; } } a=f[a][0]; return a;}int cal(int x,int y,int a,int b){ if(p[x].b b) return 0; if(a<=p[x].a&&p[x].b<=b) return p[x].s-p[y].s; return cal(p[x].L,p[y].L,a,b)+cal(p[x].R,p[y].R,a,b);}int getCnt(int x,int y,int a,int b){ return cal(root[x],root[f[y][0]],a,b);}int search(int x,int y,int k,int a,int b,int t){ if(x==y) return x; if(y==f[x][0]) { if(k==1&&a<=w[x]&&w[x]<=b) return x; return y; } if(f[x][t]==0||dep[f[x][t]]<=dep[y]) { return search(x,y,k,a,b,t-1); } int cnt=getCnt(x,f[x][t],a,b); if(k<=cnt) return search(x,f[x][t],k,a,b,t-1); else return search(f[f[x][t]][0],y,k-cnt,a,b,t-1);}void deal(){ int i,s,t,a,b,k,temp,lca; int left,right; while(m--) { scanf("%d%d%d%d%d",&s,&t,&a,&b,&k); lca=getLCA(s,t); left=getCnt(s,lca,a,b); right=getCnt(t,lca,a,b); temp=(a<=w[lca]&&w[lca]<=b); if(left+right-temp =k) i=search(s,lca,k,a,b,19); else i=search(t,lca,left+right-temp-k+1,a,b,19); printf("%d\n",i); }}int main(){ while(scanf("%d%d",&n,&m)!=-1) { int i,j,u,v; for(i=1;i<=n;i++) g[i].clear(); memset(f,0,sizeof(f)); for(i=1;i<=n;i++) scanf("%d",w+i); for(i=1;i