【CF739E】Gosha is hunting
题意:有n个小精灵,你有a个普通球和b个超级球,用普通球抓住第i只小精灵的概率为$A_i$,用超级球抓住第i只小精灵的概率为$u_i$。你必须一开始就决定向哪些精灵投掷哪些精灵球,同种的球只能对一个精灵用一次,可以对一只精灵投掷两种球,如果两次中有一次抓到则视为抓到。问你如果采用最优的方案,最终抓到小精灵的期望个数是多少。
$n\le 2000$。
题解:我们先将所有小精灵按$B$排序,然后我们枚举最后一个投b或ab的小精灵i,那么不难证明i左边的所有小精灵都是b或a或ab,i右面的小精灵都是0或a。接着我们想把左面的三种情况拆开,不难发现$A_x+B_x-A_xB_x+B_y>B_x+A_y+B_y-A_yB_y$->$(1-B_x)A_x>(1-B_y)A_y$,所以只要将i左边按$(1-B)A$排序,然后就可以枚举j,满足[1,j]都是ab或b,(j,i]都是a或b。此时我们就可以先假设[1,i]全选b,则[1,j]中每个点选ab的贡献就是$A-AB$,(j,i]中每个点选a的贡献就是$A-B$,(i,n]中每个点选a的贡献是$A$。我们只需要用一个数据结构维护前k大值的和即可。用treap比较容易,当然我懒,用的是两个对顶的堆来维护。
时间复杂度$O(n^2\log n)$。
#include <cstdio> #include <cstring> #include <iostream> #include <algorithm> #include <queue> #define lson x<<1 #define rson x<<1|1 using namespace std; const int maxn=2010; int n,A,B; double ans,sum; struct node { double a,b,c; }p[maxn]; struct heap { priority_queue<double> a,b; inline double top() { while(!b.empty()&&a.top()==b.top()) a.pop(),b.pop(); return a.top(); } inline int size() {return a.size()-b.size();} inline void erase(double x) {b.push(x);} inline void push(double x) {a.push(x);} inline void pop() { while(!b.empty()&&a.top()==b.top()) a.pop(),b.pop(); a.pop(); } inline void clr() { while(!a.empty()) a.pop(); while(!b.empty()) b.pop(); } }; struct bst { heap p1,p2; int lim; inline void insert(double x) { p1.push(-x),sum+=x; if(p1.size()>lim) p2.push(-p1.top()),sum+=p1.top(),p1.pop(); } inline void del(double x) { if(x<=p2.top()) p2.erase(x); else { sum-=x,p1.erase(-x); if(p1.size()<lim&&p2.size()) p1.push(-p2.top()),sum+=p2.top(),p2.pop(); } } inline void clr() {p1.clr(),p2.clr();} }b1,b2; bool cmp1(const node &a,const node &b) { return a.b>b.b; } bool cmp2(const node &a,const node &b) { return (1-a.a)*a.b>(1-b.a)*b.b; } int main() { scanf("%d%d%d",&n,&A,&B); int i,j; for(i=1;i<=n;i++) scanf("%lf",&p[i].a); for(i=1;i<=n;i++) scanf("%lf",&p[i].b),p[i].c=1-(1-p[i].a)*(1-p[i].b); sort(p+1,p+n+1,cmp1); double sumb=0; for(i=1;i<B;i++) sumb+=p[i].b; for(i=B;i<=min(n,A+B);i++) { sumb+=p[i].b; b1.clr(),b2.clr(),b1.lim=A-i+B,b2.lim=i-B,sum=0; sort(p+1,p+i+1,cmp2); for(j=1;j<=i;j++) b2.insert(p[j].a-p[j].b); for(j=i+1;j<=n;j++) b1.insert(p[j].a); ans=max(ans,sumb+sum); for(j=1;j<=B;j++) { b2.del(p[j].a-p[j].b),b1.insert(p[j].c-p[j].b); ans=max(ans,sumb+sum); } } printf("%.6lf",ans); return 0; }