维护序列

洛谷-P2023-维护序列

实现

#include <cstdio>
#include <iostream>
// Renames the function `plus` at the preprocessor level; presumably this avoids
// a clash with std::plus, which `using namespace std;` may bring into scope.
#define plus PlUs
// Index of the left child of the node whose index is held in a variable named
// `rt`. Pure text substitution: `rt` must be in scope wherever `ls` is used.
#define ls rt << 1
// Index of the right child of node `rt` (same requirement on `rt`). The
// expansion is not parenthesized, so it is only safe where surrounding
// operators cannot regroup it, e.g. as a whole function argument or subscript.
#define rs rt << 1 | 1
using namespace std;

// 64-bit integer used for all sums and tags.
using LL = long long;
// Capacity of the input array (maximum element count plus slack).
const int N = 100000 + 10;

// One segment-tree node.
struct node {
    int l;    // left end of the index range this node covers
    int r;    // right end of the index range this node covers
    LL data;  // aggregated sum of the range, as maintained by update()
    LL add;   // pending addition tag
    LL mul;   // pending multiplication tag
};
// Shorthand accessors for the fields of node number x in the global array `t`.
#define l(x) t[x].l
#define r(x) t[x].r
#define data(x) t[x].data
#define add(x) t[x].add
#define mul(x) t[x].mul
int n;          // number of sequence elements (read first in main)
int M;          // modulus applied to every stored sum and tag
int q;          // number of operations still to process
LL a[N];        // initial sequence values; main fills a[1..n]
node t[4 * N];  // segment-tree storage, root at index 1

void update (int rt) {
    data(rt) = (data(ls) + data(rs)) % M;
}
void build (int rt, int l, int r) {
    l(rt) = l, r(rt) = r, mul(rt) = 1; // 初始化 mul 标记
    if (l == r) {
        data(rt) = a[l];
        return;
    }
    int mid = l + r >> 1;
    build(ls, l, mid);
    build(rs, mid + 1, r);
    update(rt);
}
void spread (int rt) {
    data(ls) = (data(ls) * mul(rt) + add(rt) * (r(ls) - l(ls) + 1)) % M;
    data(rs) = (data(rs) * mul(rt) + add(rt) * (r(rs) - l(rs) + 1)) % M;

    mul(ls) = mul(ls) * mul(rt) % M;
    mul(rs) = mul(rs) * mul(rt) % M;

    add(ls) = (add(ls) * mul(rt) + add(rt)) % M;
    add(rs) = (add(rs) * mul(rt) + add(rt)) % M;

    mul(rt) = 1;
    add(rt) = 0;
}
void plus (int rt, int l, int r, int val) {
    if (l <= l(rt) && r(rt) <= r) {
        data(rt) = (data(rt) + val * (r(rt) - l(rt) + 1)) % M;
        add(rt) = (add(rt) + val) % M;
        return;
    }
    spread(rt);
    int mid = l(rt) + r(rt) >> 1;
    if (l <= mid) plus(ls, l, r, val);
    if (r >= mid + 1) plus(rs, l, r, val);
    update(rt);
}
void multiply (int rt, int l, int r, int val) {
    if (l <= l(rt) && r(rt) <= r) {
        data(rt) = data(rt) * val % M;
        mul(rt) = mul(rt) * val % M;
        add(rt) = add(rt) * val % M;
        return;
    }
    spread(rt);
    int mid = l(rt) + r(rt) >> 1;
    if (l <= mid) multiply(ls, l, r, val);
    if (r >= mid + 1) multiply(rs, l, r, val);
    update(rt);
}
// Returns the sum of the elements with index in [l, r] found in the subtree
// of node rt. Partial results are combined modulo M.
LL query (int rt, int l, int r) {
    // Node lies completely inside the request: its stored sum is the answer.
    if (l(rt) >= l && r >= r(rt)) return data(rt);

    spread(rt);  // make the children current before reading them
    const int mid = (l(rt) + r(rt)) >> 1;
    LL total = 0;
    if (mid >= l) total = (total + query(ls, l, r)) % M;
    if (mid < r) total = (total + query(rs, l, r)) % M;
    return total;
}

// Reads n, M, the n initial values and the operation count, builds the tree,
// then processes each operation:
//   1 l r val - multiply the range by val
//   2 l r val - add val to the range
//   otherwise - read l r and print the range sum
int main () {
    scanf("%d%d", &n, &M);
    for (int i = 1; i <= n; ++i) scanf("%lld", a + i);
    scanf("%d", &q);

    build(1, 1, n);
    while (q-- != 0) {
        int opt;
        scanf("%d", &opt);
        switch (opt) {
            case 1: {
                int l, r, val;
                scanf("%d%d%d", &l, &r, &val);
                multiply(1, l, r, val);
                break;
            }
            case 2: {
                int l, r, val;
                scanf("%d%d%d", &l, &r, &val);
                plus(1, l, r, val);
                break;
            }
            default: {
                int l, r;
                scanf("%d%d", &l, &r);
                printf("%lld\n", query(1, l, r));
                break;
            }
        }
    }
    return 0;
}
最后修改于: