0/1 分数规划是一种常见的模型:给你 n 个价值 $a_i$ 与 n 个代价 $b_i$,让你选出 m 个数字,使得 $ \sum \frac {a_i} {b_i} $ 最大。显然这种题目可以用二分,但是有一种更优秀的方法:Dinkelbach 迭代法。
先来看一道裸的例题:
例题:最大化平均值
题目描述
有 n 个物品的重量和价值分别是 $ w_i $ 和 $ v_i$。从中选出 k 个物品使得单位重量的价值最大。
输入输出格式
输入格式
第一行输入两个整数 n,k ;
第二行输入 n 个整数 $w_i$ ;
第三行输入 n 个整数 $v_i$;
输出格式
输出一个浮点数,保留两位小数。
样例输入输出
样例输入
3 2
2 5 2
2 3 1
样例输出
0.75
约定
$ 1<=k<=n<=500000,1<=wi,vi<=10^6 $
直接二分解法
显然题目的意思是:我们需要找到 k 个物品使得 $\displaystyle \frac {v_1+v_2+\dots+v_k} {w_1+w_2+\dots+w_k}$ 最大。假设当前二分 ans,要让 ans 逼近答案,那么:
$$ \displaystyle \frac {v_1+v_2+\dots+v_k} {w_1+w_2+\dots+w_k} \geqslant ans \\ \iff \sum_{i=1}^{k} (v_i-ans \ast w_i) \geqslant 0 $$
显然左边是个一次函数。我们只要二分枚举 ans,造出左边所有 $v_i-ans\ast w_i$ 排序取前 k 大就可以了,复杂度大约是 $\Theta (N \ast log_2 N \ast log_2 sum)$。
参考代码
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
const int maxn=10005;
int n,m,w[maxn],v[maxn];
double ans,q[maxn];
inline bool cmp(double x,double y){
return x>y;
}
inline bool check(double x){
for (int i=1;i<=n;i++) q[i]=(double)v[i]-(double)w[i]*x;
sort(q+1,q+1+n,cmp);
double cnt=0.0;
for (int i=1;i<=m;i++) cnt+=q[i];
return cnt>=0.0;
}
int main(){
scanf("%d%d",&n,&m);
for (int i=1;i<=n;i++) scanf("%d",&w[i]);
for (int i=1;i<=n;i++) scanf("%d",&v[i]);
double L=1e-8,R=1e8;
while (R-L>1e-5){
double mid=(L+R)/2.0;
if (check(mid)){
ans=mid;L=mid+1e-5;
} else R=mid-1e-5;
}
printf("%.2f\n",ans);
return 0;
}
Dinkelbach 迭代法
仔细分析,二分实际上是盲目的。它只关心答案在那边,然后每次砍掉一半。当我们二分一个答案的时候,有可能这个时候已经求出了更优的解,我们却没有去理它,而是继续我们的二分。所以这时候就有了 Dinkelbach 迭代法。
上面这段话意思是,假设我们求出了当前的一个 ans 并且满足了上面推出来的那个奇奇怪怪的式子,那么我们就知道我们要选的 k 个物品,那么我们可以算出选这 k 个物品实际的解,这个解有可能比二分枚出来的 ans 更优秀!所以我们直接把这个解再作为 ans 验证,这样就可以更快地逼近答案了~
(如果不懂,看看代码就理解了)
参考代码
// Dinkelbach Iterative ALgorithm
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
const int maxn=500005;
const double eps=1e-7;
int n,m,w[maxn],v[maxn];
struct TempData{
double x;
int id;
}a[maxn];
double ans=1e7;
inline double myabs(double x){
if (x<0) return -x; else return x;
}
inline int read(){
int ret=0,f=1;char ch=getchar();
while (ch<'0'||ch>'9') {if (ch=='-') f=-1;ch=getchar();}
while (ch>='0'&&ch<='9') ret=ret*10+ch-'0',ch=getchar();
return ret*f;
}
inline bool cmp(TempData aa,TempData bb){
return aa.x>bb.x;
}
int main(){
n=read();m=read();
for (int i=1;i<=n;i++) w[i]=read();
for (int i=1;i<=n;i++) v[i]=read();
double now=(double)(rand()%n);
while (myabs(now-ans)>eps){
ans=now;
for (int i=1;i<=n;i++) a[i].x=(double)v[i]-(double)w[i]*ans,a[i].id=i;
sort(a+1,a+1+n,cmp);
double sum_w=0.0,sum_v=0.0;
for (int i=1;i<=m;i++){
sum_w+=w[a[i].id];
sum_v+=v[a[i].id];
}
now=sum_v/sum_w;
}
printf("%.2f\n",ans);
return 0;
}