三维偏序

模板题

洛谷-P3810-陌上花开

#include <algorithm>
#include <cstdio>
#include <iostream>
using namespace std;

// Capacity of the per-point arrays (a, b, temp, res).
const int N = 100010;
// Capacity of the Fenwick tree array c, which is indexed by z values.
const int M = 200010;
// One point of the three-dimensional partial order problem.
// The member order matters: other code aggregate-initializes this struct
// as { x, y, z, cnt, res }.
struct node {
    int x, y, z;  // the three coordinates
    int cnt;      // how many identical input points this entry stands for
    int res;      // accumulated count of dominated points (filled in by cdq)

    // Lexicographic ordering on (x, y, z); cnt and res are not compared.
    bool operator < (const node& o) const {
        if (x != o.x) return x < o.x;
        if (y != o.y) return y < o.y;
        return z < o.z;
    }
};
// n: number of input points (read in main).
int n;
// m: number of distinct points stored in b (counted by unique).
int m;
// k: largest index the Fenwick tree is updated at (read in main).
int k;
// a: raw input points, 1-based.
node a[N];
// b: deduplicated points with multiplicities, 1-based.
node b[N];
// temp: scratch buffer for the merge step in cdq.
node temp[N];
// c: Fenwick tree storage, indexed by z.
int c[M];
// res: histogram that main fills and prints.
int res[N];

// Sorts the input points a[1..n] with node::operator< and collapses every run
// of identical (x, y, z) triples into a single entry of b[1..m]. Each entry
// gets cnt = length of its run and res = 0. m is reset first, so the function
// always rebuilds b from scratch.
void unique () {
    sort(a + 1, a + n + 1);
    m = 0;
    int cnt = 0;
    for (int i = 1; i <= n; ++ i) {
        ++ cnt;
        // The last element always closes its run. Checking i == n first avoids
        // reading a[n + 1], which is not part of the input and must not be
        // relied on as a sentinel.
        const bool run_ends = i == n
            || a[i].x != a[i + 1].x
            || a[i].y != a[i + 1].y
            || a[i].z != a[i + 1].z;
        if (run_ends) {
            b[++ m] = { a[i].x, a[i].y, a[i].z, cnt, 0 };
            cnt = 0;
        }
    }
}
// Returns the value of the lowest set bit of x (0 when x is 0).
int lowbit (int x) {
    const int negated = -x;
    return x & negated;
}
// Fenwick tree point update: adds val to c at idx and at every later position
// reached by repeatedly adding lowbit, up to and including k.
// idx is expected to be >= 1; with idx == 0 the position never advances.
void add (int idx, int val) {
    for (int pos = idx; pos <= k; pos += lowbit(pos))
        c[pos] += val;
}
// Fenwick tree prefix query: returns the total of all values added at
// positions 1..idx (0 when idx < 1).
int sum (int idx) {
    int total = 0;
    for (int pos = idx; pos > 0; pos -= lowbit(pos))
        total += c[pos];
    return total;
}
// Divide-and-conquer step over b[l..r].
// After both halves have been processed recursively, they are merged by y.
// Every left-half entry is inserted into the Fenwick tree at its z with
// weight cnt. Every right-half entry adds to its res the total weight of
// left-half entries already inserted with z <= its own z. Ties on y take the
// left entry first, so equal y counts as "not greater".
// On return b[l..r] is ordered by non-decreasing y and the Fenwick tree is
// back in the state it had on entry.
void cdq (int l, int r) {
    if (l >= r) return;

    const int mid = (l + r) >> 1;
    cdq(l, mid);
    cdq(mid + 1, r);

    int left = l;        // next unmerged entry of the left half
    int right = mid + 1; // next unmerged entry of the right half
    int out = l;         // next free slot in temp
    while (left <= mid || right <= r) {
        const bool take_left =
            right > r || (left <= mid && b[left].y <= b[right].y);
        if (take_left) {
            add(b[left].z, b[left].cnt);
            temp[out ++] = b[left ++];
        } else {
            b[right].res += sum(b[right].z);
            temp[out ++] = b[right ++];
        }
    }

    // b[l..mid] is still untouched here, so it lists exactly the entries that
    // were inserted above; remove them before overwriting b.
    for (int p = l; p <= mid; ++ p) add(b[p].z, -b[p].cnt);
    for (int p = l; p <= r; ++ p) b[p] = temp[p];
}

int main () {
    scanf("%d%d", &n, &k);
    for (int i = 1; i <= n; ++ i)
        scanf("%d%d%d", &a[i].x, &a[i].y, &a[i].z);

    unique();
    cdq(1, m);
    for (int i = 1; i <= m; ++ i)
        res[b[i].res + b[i].cnt] += b[i].cnt;
    for (int i = 1; i <= n; ++ i)
        printf("%d\n", res[i]);
    return 0;
}
最后修改于: