窗内的星星

AcWing-248-窗内的星星

分析

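见《算法竞赛进阶指南》第220页。大致思路:把每颗星星 (x, y, c) 看作一个矩形,它对横坐标 [x, x+W-1]、纵坐标 [y, y+H-1] 范围内的点贡献亮度 c;于是问题转化为求平面上被覆盖权值之和最大的点。对纵坐标离散化后,用扫描线沿 x 方向处理矩形的左右边界,并用支持区间加、查询全局最大值的线段树维护答案。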

实现

#include <algorithm>
#include <cstdio>
#include <iostream>
// Child indices of the segment-tree node whose index is in the local variable `rt`.
// The bodies are parenthesised so the macros keep their meaning when they are
// used inside a larger expression (e.g. `ls + 1`).
#define ls (rt << 1)
#define rs (rt << 1 | 1)
using namespace std;

// Shorthand for the unsigned type used for coordinates.
using UI = unsigned int;
// Base capacity; the global arrays in this file are sized as multiples of it.
constexpr int N = 10010;
// One sweep event: at position x, the value c is applied to the range [y1, y2].
struct boundary {
    UI x;       // position of the event along the sweep axis
    UI y1, y2;  // inclusive range the event applies to
    int c;      // amount passed to modify() for this event
    // Events are ordered by x; at equal x the one with the smaller c comes first.
    bool operator < (const boundary& o) const {
        return x != o.x ? x < o.x : c < o.c;
    }
};
int m;              // number of values currently stored in nums[1..m]
UI nums[2 * N];     // values collected for coordinate compression (1-based)
boundary a[2 * N];  // sweep events (1-based)
// One segment-tree node covering the index range [l, r].
struct node {
    int l, r;
    int data; // maximum over the node's range
    int add;  // increment still waiting to be pushed down to the children
};
// Field accessors for the node stored at index x of the global array t.
#define l(x) t[x].l
#define r(x) t[x].r
#define data(x) t[x].data
#define add(x) t[x].add
// Storage for the segment tree; build() places the root at index 1.
node t[8 * N];

void discrete () {
    sort(nums + 1, nums + m + 1);
    m = unique(nums + 1, nums + m + 1) - (nums + 1);
}
int query (int x) {
    return lower_bound(nums + 1, nums + m + 1, x) - nums;
}
// Initialises the subtree rooted at node rt to cover [l, r], with every
// maximum and every pending increment set to zero.
void build (int rt, int l, int r) {
    l(rt) = l, r(rt) = r;
    data(rt) = add(rt) = 0;
    // `>=` rather than `==`: an empty range (l > r) never satisfies l == r
    // and would otherwise recurse without end.
    if (l >= r) return;
    int mid = (l + r) >> 1;
    build(ls, l, mid);
    build(rs, mid + 1, r);
}
// Pushes the pending increment of node rt down to both children, then clears it.
void spread (int rt) {
    const int pending = add(rt);
    if (pending == 0) return;

    const int children[2] = { ls, rs };
    for (int child : children) {
        data(child) += pending;
        add(child) += pending;
    }
    add(rt) = 0;
}
// Adds val to every position of [l, r] inside the subtree of node rt and
// recomputes the maxima of the nodes visited on the way back up.
void modify (int rt, int l, int r, int val) {
    const bool covered = l <= l(rt) && r(rt) <= r;
    if (covered) {
        // Whole node inside the update range: record the increment lazily.
        add(rt) += val;
        data(rt) += val;
        return;
    }
    spread(rt);
    const int mid = (l(rt) + r(rt)) >> 1;
    if (mid >= l) modify(ls, l, r, val);
    if (mid < r) modify(rs, l, r, val);
    data(rt) = max(data(ls), data(rs));
}

int main () {
    int n, W, H;
    while (scanf("%d%d%d", &n, &W, &H) != EOF) {
        m = 0;
        for (int i = 1, x, y, c; i <= n; ++ i) {
            scanf("%d%d%d", &x, &y, &c);
            a[i] = { x, y, y + H - 1, c };
            a[n + i] = { x + W, y, y + H - 1, -c };
            nums[++ m] = y;
            nums[++ m] = y + H - 1;
        }

        discrete();
        for (int i = 1; i <= n; ++ i) {
            a[i].y1 = a[n + i].y1 = query(a[i].y1);
            a[i].y2 = a[n + i].y2 = query(a[i].y2);
        }

        build(1, 1, m);
        sort(a + 1, a + 2 * n + 1);
        int res = 0;
        for (int i = 1; i <= 2 * n; ++ i) {
            modify(1, a[i].y1, a[i].y2, a[i].c);
            res = max(res, data(1));
        }
        printf("%d\n", res);
    }
    return 0;
}
最后修改于: