0/1 分数规划是一种常见的模型:给你 n 个价值 $a_i$ 与 n 个代价 $b_i$,让你选出 m 个数字,使得 $ \sum \frac {a_i} {b_i} $ 最大。显然这种题目可以用二分,但是有一种更优秀的方法:Dinkelbach 迭代法。

先来看一道裸的例题:

例题:最大化平均值

题目描述

有 n 个物品的重量和价值分别是 $ w_i $ 和 $ v_i$。从中选出 k 个物品使得单位重量的价值最大。

输入输出格式

输入格式

第一行输入两个整数 n,k ;
第二行输入 n 个整数 $w_i$ ;
第三行输入 n 个整数 $v_i$;

输出格式

输出一个浮点数,保留两位小数。

样例输入输出

样例输入

3 2
2 5 2
2 3 1

样例输出

0.75

约定

$ 1<=k<=n<=500000,1<=wi,vi<=10^6 $

直接二分解法

显然题目的意思是:我们需要找到 k 个物品使得 $\displaystyle \frac {v_1+v_2+\dots+v_k} {w_1+w_2+\dots+w_k}$ 最大。假设当前二分 ans,要让 ans 逼近答案,那么:

$$ \displaystyle \frac {v_1+v_2+\dots+v_k} {w_1+w_2+\dots+w_k} \geqslant ans \\ \iff \sum_{i=1}^{k} (v_i-ans \ast w_i) \geqslant 0 $$

显然左边是个一次函数。我们只要二分枚举 ans,造出左边所有 $v_i-ans\ast w_i$ 排序取前 k 大就可以了,复杂度大约是 $\Theta (N \ast log_2 N \ast log_2 sum)$。

参考代码

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
const int maxn=10005;
int n,m,w[maxn],v[maxn];
double ans,q[maxn];
inline bool cmp(double x,double y){
    return x>y;
}
inline bool check(double x){
    for (int i=1;i<=n;i++) q[i]=(double)v[i]-(double)w[i]*x;
    sort(q+1,q+1+n,cmp);
    double cnt=0.0;
    for (int i=1;i<=m;i++) cnt+=q[i];
    return cnt>=0.0;
}
int main(){
    scanf("%d%d",&n,&m);
    for (int i=1;i<=n;i++) scanf("%d",&w[i]);
    for (int i=1;i<=n;i++) scanf("%d",&v[i]);
    double L=1e-8,R=1e8;
    while (R-L>1e-5){
        double mid=(L+R)/2.0;
        if (check(mid)){
            ans=mid;L=mid+1e-5;
        } else R=mid-1e-5;
    }
    printf("%.2f\n",ans);
    return 0;
}

Dinkelbach 迭代法

仔细分析,二分实际上是盲目的。它只关心答案在那边,然后每次砍掉一半。当我们二分一个答案的时候,有可能这个时候已经求出了更优的解,我们却没有去理它,而是继续我们的二分。所以这时候就有了 Dinkelbach 迭代法。

上面这段话意思是,假设我们求出了当前的一个 ans 并且满足了上面推出来的那个奇奇怪怪的式子,那么我们就知道我们要选的 k 个物品,那么我们可以算出选这 k 个物品实际的解,这个解有可能比二分枚出来的 ans 更优秀!所以我们直接把这个解再作为 ans 验证,这样就可以更快地逼近答案了~

(如果不懂,看看代码就理解了)

参考代码

// Dinkelbach Iterative ALgorithm
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
const int maxn=500005;
const double eps=1e-7;
int n,m,w[maxn],v[maxn];
struct TempData{
    double x;
    int id;
}a[maxn];
double ans=1e7;
inline double myabs(double x){
    if (x<0) return -x; else return x;
}
inline int read(){
    int ret=0,f=1;char ch=getchar();
    while (ch<'0'||ch>'9') {if (ch=='-') f=-1;ch=getchar();}
    while (ch>='0'&&ch<='9') ret=ret*10+ch-'0',ch=getchar();
    return ret*f;
}
inline bool cmp(TempData aa,TempData bb){
    return aa.x>bb.x;
}
int main(){
    n=read();m=read();
    for (int i=1;i<=n;i++) w[i]=read();
    for (int i=1;i<=n;i++) v[i]=read();
    double now=(double)(rand()%n);
    while (myabs(now-ans)>eps){
        ans=now;
        for (int i=1;i<=n;i++) a[i].x=(double)v[i]-(double)w[i]*ans,a[i].id=i;
        sort(a+1,a+1+n,cmp);
        double sum_w=0.0,sum_v=0.0;
        for (int i=1;i<=m;i++){
            sum_w+=w[a[i].id];
            sum_v+=v[a[i].id];
        }
        now=sum_v/sum_w;
    }
    printf("%.2f\n",ans);
    return 0;
}