树状数组、线段树模板（C++）

树状数组 线段树模板

参考 OI Wiki（非常好的资料）

树状数组

  • 单点修改 区间查询总和
// Capacity: positions 1..500000 plus a little slack.
const int maxN = 500000 + 5;
int a[maxN];   // not referenced by the code in this snippet
int c[maxN];   // Fenwick tree nodes, 1-indexed (c[0] is never read)
int m, n;      // m = number of operations, n = array length
// Returns the value of the lowest set bit of x (e.g. 12 -> 4).
// In two's complement, -x flips every bit above that bit, so AND isolates it.
int lowbit(int x) {
  const int negated = -x;
  return negated & x;
}
// Point update a[x] += k: visit every Fenwick node whose range covers x
// (index i climbs by its lowest set bit until it passes n).
void add(int x, int k) {
  for (int i = x; i <= n; i += lowbit(i))
    c[i] += k;
}
// Prefix sum a[1] + ... + a[x]; returns 0 when x <= 0.
// Each step strips the lowest set bit of x, so at most O(log n) nodes are read.
int getsum(int x) {
  int total = 0;
  for (; x > 0; x -= lowbit(x))
    total += c[x];
  return total;
}
// Input: n m, then n initial values, then m operations "op x y":
//   op == 1 -> a[x] += y
//   otherwise -> print a[x] + ... + a[y]
int main(){
    // Up to 5e5 queries: untie/unsync the streams and avoid endl, which
    // flushed the output buffer after every single answer.
    ios::sync_with_stdio(false);
    cin.tie(0);
    int op, x, y;
    cin >> n >> m;
    // c[] is a zero-initialised global, so the tree can be filled by point adds.
    for (int i = 1; i <= n; i++) {
        cin >> x;
        add(i, x);
    }
    for (int i = 1; i <= m; i++) {
        cin >> op >> x >> y;
        if (op == 1)
            add(x, y);
        else
            cout << getsum(y) - getsum(x - 1) << '\n';
    }
    return 0;
}
  • 区间修改 单点查询（另附区间求和函数 getsum1）
// This snippet uses `ll` throughout but never defined it; define it here so
// the program compiles on its own.
typedef long long ll;
const int MAXN = 500000 + 5;
int m, n;   // m = number of operations, n = array length
// Two Fenwick trees over the difference array diff[i] = a[i] - a[i-1]:
//   t1 accumulates diff[i], t2 accumulates i * diff[i] (both maintained by add).
ll t1[MAXN], t2[MAXN];
// Lowest set bit of x; used to step between Fenwick nodes.
inline ll lowbit(ll x) {
  const ll negated = -x;
  return negated & x;
}
// Adds v at difference position k.
// t1 receives v and t2 receives k * v; getsum1 needs both to recover range sums.
// The product uses the starting k and is computed once, before the walk.
void add(ll k, ll v) {
  const ll weighted = k * v;
  for (ll i = k; i <= n; i += lowbit(i)) {
    t1[i] += v;
    t2[i] += weighted;
  }
}
// Prefix sum t[1] + ... + t[k] of the given Fenwick array (t1 or t2).
ll getsum(ll *t, ll k) {
  ll acc = 0;
  for (; k != 0; k -= lowbit(k))
    acc += t[k];
  return acc;
}
// Range update a[l..r] += v.
// On the difference array this is two point updates: diff[l] += v, diff[r + 1] -= v.
void add1(ll l, ll r, ll v) {
  add(l, v);
  add(r + 1, -v);
}
// Range sum a[l] + ... + a[r], using the identity
//   prefix(x) = (x + 1) * sum_{i<=x} diff[i] - sum_{i<=x} i * diff[i]
// so the result is prefix(r) - prefix(l - 1).
ll getsum1(ll l, ll r) {
  const ll upper = (r + 1ll) * getsum(t1, r);
  const ll lower = 1ll * l * getsum(t1, l - 1);
  const ll weighted = getsum(t2, r) - getsum(t2, l - 1);
  return upper - lower - weighted;
}
int main(){
    ll x,y=0,z,ans;
    cin>>n>>m;
    for(int i = 1;i<=n;i++){
        cin>>x;
        add(i,x-y);
        y = x;
    }
    for(int i=1;i<=m;i++){
        cin>>x;
        if(x==1){
            cin>>x>>y>>z;
            add1(x,y,z);
        }else {
            ans=0;
            cin>>x;
            cout<<getsum(t1,x)<<endl;
        }
    }
    return 0;
}

线段树

  • 单点修改 区间查询（区间最大值）
const int MAXN = 200000 + 5;
int n, m;           // n = array length, m = number of operations
int d[MAXN << 2];   // tree nodes: d[p] = max over p's segment; children are 2p and 2p+1
int a[MAXN];        // input array, 1-indexed
// Builds the max tree over a[s..t], storing node p's value in d[p].
void build(int s, int t, int p = 1)
{
    if (s != t) {
        const int mid = (s + t) / 2;
        const int lc = p * 2;
        const int rc = lc + 1;
        build(s, mid, lc);
        build(mid + 1, t, rc);
        d[p] = max(d[lc], d[rc]);
    } else {
        d[p] = a[s];   // leaf
    }
}
// Point assignment a[in] = c.
// Descends to the leaf, then recomputes the maxima on the path back up.
void update(int in, int c, int s = 1, int t = n, int p = 1)
{
    if (s == t && in == s) {
        d[p] = c;
        return;
    }
    const int mid = (s + t) >> 1;
    const int lc = p << 1;
    if (in > mid)
        update(in, c, mid + 1, t, lc | 1);
    else
        update(in, c, s, mid, lc);
    d[p] = max(d[lc], d[lc | 1]);
}

// Range-MAXIMUM query over [l, r] (the name "getsum" is historical).
// [s, t] is node p's segment.
// Only children that overlap [l, r] are visited, and their answers are
// combined directly, so no "minus infinity" sentinel is needed. The old code
// seeded the result with 0 and returned 0 for ranges whose elements were all negative.
int getsum(int l, int r, int s = 1, int t = n, int p = 1)
{
    if (l <= s && t <= r)
        return d[p];
    int m = (s + t) >> 1;
    if (r <= m)                      // query lies entirely in the left child
        return getsum(l, r, s, m, p * 2);
    if (l > m)                       // query lies entirely in the right child
        return getsum(l, r, m + 1, t, p * 2 + 1);
    return max(getsum(l, r, s, m, p * 2),
               getsum(l, r, m + 1, t, p * 2 + 1));
}
int main()
{
    int t, x, y;
    char com[2];
    scanf("%d%d", &n, &m) != EOF
    for (int i = 1; i <= n; i++)
        scanf("%d", &a[i]);
    build(1,n);
    for (int i = 1; i <= m; i++)
    {
        scanf("%s", com);
        scanf("%d%d", &x, &y);
        if (com[0] == 'U')
            update(x, y);
        else 
            cout << getsum(x, y) << endl;   
    }
    return 0;
}
  • 区间修改 区间查询总和
typedef long long ll;
// "Ten years of OI wasted for forgetting long long": sums here can exceed int.
const int MAXN = 100005;
int n, m;           // n = array length, m = number of operations
ll d[MAXN << 2];    // segment tree: d[p] = sum over node p's segment (includes b[p])
ll a[MAXN];         // original array, 1-indexed
ll b[MAXN << 2];    // lazy tags: per-element addition not yet pushed to p's children
void build(int s, int t, ll p=1) {
    if (s == t) {
        d[p] = a[s];
        return;
    }
    ll m = (s + t) / 2;
    build(s, m, p * 2), build(m + 1, t, p * 2 + 1);
    d[p] = d[p * 2] + d[(p * 2) + 1];
}
void pushdown(int p ,int m) {//下传懒惰标记
    b[p * 2] += b[p];
    b[p * 2 + 1] += b[p];
    d[p * 2] += b[p] * (m - (m>>1));
    d[p * 2 + 1] += b[p] * (m >> 1);
    b[p] = 0;
}
// Range addition: a[l..r] += c. [s, t] is node p's segment.
// A fully covered node absorbs the addition into d[p] and records it in b[p].
void update(int l, int r, ll c, int s=1, int t=n, int p=1) {
    if (s >= l && r >= t) {
        b[p] += c;
        d[p] += (t - s + 1) * c;
        return;
    }
    if (s != t && b[p] != 0)
        pushdown(p, t - s + 1);
    const int mid = (s + t) / 2;
    if (mid >= l) update(l, r, c, s, mid, p * 2);
    if (mid < r) update(l, r, c, mid + 1, t, p * 2 + 1);
    d[p] = d[p * 2] + d[p * 2 + 1];
}
// Range sum over a[l..r]. [s, t] is node p's segment.
// Pending tags are pushed down before recursing, so children are up to date.
ll getsum(int l, int r, int s=1, int t=n, int p=1) {
    if (s >= l && r >= t) return d[p];
    if (b[p] != 0) pushdown(p, t - s + 1);
    const int mid = (s + t) / 2;
    ll leftPart = 0, rightPart = 0;
    if (mid >= l) leftPart = getsum(l, r, s, mid, p * 2);
    if (mid < r) rightPart = getsum(l, r, mid + 1, t, p * 2 + 1);
    return leftPart + rightPart;
}
// Input: n m, then n values, then m commands:
//   "C x y z" -> a[x..y] += z
//   any other letter followed by "x y" -> print a[x] + ... + a[y]
int main() {
    int x, y;
    ll z;
    char op;
    if (scanf("%d%d", &n, &m) != 2)
        return 0;
    // a, d and b are globals and therefore already zero-initialised,
    // so the former memset calls were redundant.
    for (int i = 1; i <= n; i++)
        scanf("%lld", &a[i]);
    build(1, n);
    for (int i = 1; i <= m; i++) {
        // " %c" reads one command letter. The old scanf("%s") into char[2]
        // overflowed on any token longer than one character.
        scanf(" %c", &op);
        if (op == 'C') {
            scanf("%d%d%lld", &x, &y, &z);
            update(x, y, z);
        } else {
            scanf("%d%d", &x, &y);
            printf("%lld\n", getsum(x, y));
        }
    }
    return 0;
}

版权声明:本文为weixin_46879396原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。