整体分治

原理

见《进阶指南》第247页。

模板题

AcWing-255-第K小数

#include <cstdio>
#include <iostream>
#define inf 0x3f3f3f3f
using namespace std;

// Static capacity limits. N sizes the arrays indexed by sequence position
// (and the answer array); N + M sizes the arrays that hold every record.
const int N = 100010;
const int M = 10010;
// One record handled by the offline solver: either a value placed at a
// sequence position or a k-th smallest query over a range.
struct operation {
    int type;  // record kind: 0 = modification (value placement), 1 = query
    int idx;   // type 0: position in the sequence; type 1: index of the query
    int val;   // type 0: the value placed at position idx
    int l;     // type 1: left end of the queried range
    int r;     // type 1: right end of the queried range
    int k;     // type 1: rank still being searched for
};
int n;  // number of sequence elements; also the upper bound used by add()
int m;  // number of queries

// a holds every record (values first, then queries, 1-indexed);
// la / ra are scratch buffers solve() uses to split a range of a in two.
operation a[N + M];
operation la[N + M];
operation ra[N + M];

int c[N];    // Fenwick tree storage, 1-indexed, maintained by add() / sum()
int res[N];  // res[q] receives the answer of query number q

// Isolates the lowest set bit of x (yields 0 when x is 0). add() and sum()
// use it as the step when walking the Fenwick tree.
int lowbit (int x) {
    const int negated = -x;
    return negated & x;
}
// Fenwick tree point update: adds val to every node of c that covers
// position idx, climbing while the node index does not exceed n.
// idx is expected to be >= 1 (an index of 0 would never advance).
void add (int idx, int val) {
    for (int pos = idx; pos <= n; pos += lowbit(pos))
        c[pos] += val;
}
// Fenwick tree prefix query: returns the accumulated total of positions
// 1..idx in c. Yields 0 when idx is below 1.
int sum (int idx) {
    int total = 0;
    for (int pos = idx; pos > 0; pos -= lowbit(pos))
        total += c[pos];
    return total;
}
void solve (int l, int r, int beg, int end) {
    if (beg > end) return;
    if (l == r) {
        for (int i = beg; i <= end; ++ i)
            if (a[i].type == 1)
                res[a[i].idx] = l;
        return;
    }

    int mid = l + r >> 1;
    int lp = 0, rp = 0;
    for (int i = beg; i <= end; ++ i) {
        if (a[i].type == 0) {
            if (a[i].val <= mid) {
                add(a[i].idx, 1);
                la[++ lp] = a[i];
            } else {
                ra[++ rp] = a[i];
            }
        } else {
            int cnt = sum(a[i].r) - sum(a[i].l - 1);
            if (cnt >= a[i].k) {
                la[++ lp] = a[i];
            } else {
                a[i].k -= cnt;
                ra[++ rp] = a[i];
            }
        }
    }
    for (int i = end; i >= beg; -- i)
        if (a[i].type == 0 && a[i].val <= mid)
            add(a[i].idx, -1);
    for (int i = 1; i <= lp; ++ i)
        a[beg + i - 1] = la[i];
    for (int i = 1; i <= rp; ++ i)
        a[beg + lp + i - 1] = ra[i];
    
    solve(l, mid, beg, beg + lp - 1);
    solve(mid + 1, r, beg + lp, end);
}

int main () {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; ++ i) {
        a[i].type = 0;
        a[i].idx = i; // 1 ~
        scanf("%d", &a[i].val);
    }
    for (int i = n + 1; i <= n + m; ++ i) {
        a[i].type = 1;
        a[i].idx = i - n; // 1 ~
        scanf("%d%d%d", &a[i].l, &a[i].r, &a[i].k);
    }

    solve(-inf, inf, 1, n + m);
    for (int i = 1; i <= m; ++ i) printf("%d\n", res[i]);
    return 0;
}
最后修改于: