算法用于在一个未知长度的序列中随机取出某个或某些样本
其中每个样本被取出的概率均为 1/n
算法
假设数据序列的规模为 n,需要采样的数量的为 k。
首先构建一个可容纳 k 个元素的数组,将序列的前 k 个元素放入数组中。
然后从第 k+1 个元素开始,以 k/n 的概率来决定该元素最后是否被留在数组中(每进来一个新的元素,数组中的每个旧元素被替换的概率是相同的)。 当遍历完所有元素之后,数组中剩下的元素即为所需采取的样本。
证明
①
对于数组中第 i 个数据(i ≤ k)。在 k 步之前,被选中的概率为 1。
当第 k+1 步时,被第 k+1 个数据替换的概率 = 第k+1个元素被选中的概率 * 第i个数 被选中替换的概率,
即为
则其被保留(取到)的概率为
依次类推,在不被第 k + 1 个元素替换的前提下,不被第k+2 个数据替换的条件概率为
则运行到第 n 步时,被保留的概率:
②
对于第 j 个数据(j > k)。第 j个数据被选中的概率为 k / j。
不被第 j + 1 个元素替换的概率为
则运行到第 n步时,被保留的概率 = 被选中的概率 * 不被替换的概率,即条件概率的连乘)
代码实现
一些辅助函数
1 2 3 4 5 6 7 8 9 10 11 12
| float randomFloat() { return rand() / (RAND_MAX + 1.0f); }
int randomInt(int from, int until) { int i = (int) ((until - from) * randomFloat()) + from; return i; }
int randomInt(int until = RAND_MAX) { return randomInt(0, until); }
|
初始化样本数据
1 2 3 4
| srand(time(NULL)); for (int i = 0; i < N; ++i) { samples.push_back(randomInt(1000)); }
|
抽样过程
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
| auto bag = std::vector<int>(K);
int idx = 0; for (const auto &item: samples) { if (idx < K) { bag[idx++] = item; continue; }
if (randomInt(idx + 1) <= K) { bag[randomInt(K)] = item; } idx++; }
|
封装为类
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30
| class ReservoirSampler { public: ReservoirSampler(int size) : bag(std::vector<int>(size)), size(size) {}
void update(int sample);
void forEach(const std::function<void(const int&)> &action);
std::vector<int> bag; private: int idx = 0; int size; };
void ReservoirSampler::update(int sample) { if (idx < size) { bag[idx++] = sample; return; } if (randomInt(idx + 1) <= size) { bag[randomInt(size)] = sample; } idx++; }
void ReservoirSampler::forEach(const std::function<void(const int &)> &action) { for (const int &item: bag) action(item); }
|
封装为类之后每次更新值的时候使用update成员函数即可
为支持多种类型,也可以使用泛型类