liam0205.github.io
liam0205.github.io copied to clipboard
Alias Method: 在常数时间复杂度内非均匀地随机抽取元素 | 始终
https://liam.page/2019/12/02/non-uniform-random-choice-in-constant-time-complexity/
不忘初心,方得始终。
用 Python 实现了一下,同时实现了有放回的选取和无放回的采样方法。供后来者参考。
import random
def create_alias_table(weights):
    """Build the probability and alias tables used by the alias method.

    Every weight is normalised and scaled by ``n = len(weights)`` so the
    scaled values average to 1. Columns below 1 ("small") are topped up to
    exactly 1 by borrowing mass from columns at or above 1 ("large").

    Args:
        weights: Non-negative weights, one per candidate. Any iterable is
            accepted; at least one weight must be positive.

    Returns:
        A ``(prob, alias)`` tuple of two lists of length ``n``. To draw,
        pick a column ``i`` uniformly, keep ``i`` with probability
        ``prob[i]`` and otherwise take ``alias[i]``. Both lists are empty
        when ``weights`` is empty.

    Raises:
        ValueError: If any weight is negative, or if the weights do not sum
            to a positive number.
    """
    weights = list(weights)
    n = len(weights)
    if n == 0:
        return [], []
    if any(w < 0 for w in weights):
        raise ValueError("weights 不能包含负数")
    total_weight = sum(weights)
    if total_weight <= 0:
        # Without this check the normalisation below divides by zero.
        raise ValueError("weights 之和必须为正数")

    # Normalise, then scale by n so that the mean column height is 1.
    scaled_prob = [w * n / total_weight for w in weights]
    prob = [0.0] * n
    alias = [0] * n

    # Split the columns into those below 1 and those at or above 1.
    small = []
    large = []
    for i, sp in enumerate(scaled_prob):
        if sp < 1.0:
            small.append(i)
        else:
            large.append(i)

    # Pair one small column with one large column until one side runs out.
    while small and large:
        small_idx = small.pop()
        large_idx = large.pop()
        prob[small_idx] = scaled_prob[small_idx]
        alias[small_idx] = large_idx
        # The large column donates (1 - small) of its mass. Adding before
        # subtracting keeps the rounding error smaller than computing
        # large - (1 - small).
        scaled_prob[large_idx] = (
            scaled_prob[large_idx] + scaled_prob[small_idx] - 1.0
        )
        if scaled_prob[large_idx] < 1.0:
            small.append(large_idx)
        else:
            large.append(large_idx)

    # Whatever is left is 1 up to rounding error: make it a certain hit.
    for idx in large + small:
        prob[idx] = 1.0
        alias[idx] = idx
    return prob, alias
def alias_choice(population, weights, k=1):
    """Draw ``k`` elements from ``population`` with replacement, by weight.

    The alias table is built once from ``weights`` and then reused for all
    ``k`` draws.

    Args:
        population: Candidate elements.
        weights: Weight of each candidate, in the same order.
        k: Number of draws.

    Returns:
        A list with the ``k`` drawn elements, in draw order.

    Raises:
        ValueError: If ``population`` is empty or its length differs from
            that of ``weights``.
    """
    mismatched = len(population) != len(weights)
    if mismatched or not population:
        raise ValueError("population 和 weights 必须长度一致且非空")

    accept, fallback = create_alias_table(weights)
    last = len(population) - 1

    def draw():
        # Pick a column uniformly, then either keep it or follow its alias.
        column = random.randint(0, last)
        keep = random.random() < accept[column]
        return population[column] if keep else population[fallback[column]]

    return [draw() for _ in range(k)]
def alias_sample_once(population, weights):
    """Draw one weighted random position in ``population``.

    The return value is an index, not the element itself: ``alias_sample``
    uses it to delete the drawn entry from both its element list and its
    weight list. The alias table is rebuilt from ``weights`` on every call.

    Args:
        population: Candidate elements; only its length is used.
        weights: Weight of each candidate, in the same order.

    Returns:
        The index of the drawn candidate, in ``range(len(population))``.

    Raises:
        ValueError: If ``population`` is empty or its length differs from
            that of ``weights``.
    """
    n = len(population)
    if n != len(weights) or n == 0:
        # Same contract and message as the other sampling entry points; a
        # length mismatch would otherwise index outside the tables or
        # ignore some candidates.
        raise ValueError("population 和 weights 必须长度一致且非空")
    prob, alias = create_alias_table(weights)
    # Pick a column uniformly, then either keep it or follow its alias.
    idx = random.randint(0, n - 1)
    return idx if random.random() < prob[idx] else alias[idx]
def alias_sample(population, weights, k):
    """Draw ``k`` distinct elements from ``population``, by weight.

    Sampling is without replacement: after each draw the chosen element and
    its weight are removed from working copies, and the next draw is made
    from what remains. The caller's ``population`` and ``weights`` are never
    modified, and the returned list is always a new object.

    Args:
        population: Candidate elements.
        weights: Weight of each candidate, in the same order.
        k: Number of elements to draw; must not exceed ``len(population)``.

    Returns:
        A list of the drawn elements in draw order. When ``k`` equals
        ``len(population)`` every element is returned, in its original
        order.

    Raises:
        ValueError: If ``population`` is empty, its length differs from
            that of ``weights``, or ``k`` exceeds ``len(population)``.
    """
    if len(population) != len(weights) or not population:
        raise ValueError("population 和 weights 必须长度一致且非空")
    if k > len(population):
        raise ValueError("采样数量 k 不能超过候选集长度")
    if k == len(population):
        # Everything is selected. Return a copy so that mutating the result
        # can never change the caller's population.
        return list(population)

    # Work on copies so the inputs stay untouched.
    pop_copy = list(population)
    weight_copy = list(weights)
    samples = []
    for _ in range(k):
        # The draw yields an index into the current working copies.
        chosen_idx = alias_sample_once(pop_copy, weight_copy)
        # Remove the drawn element and its weight so it cannot recur.
        samples.append(pop_copy.pop(chosen_idx))
        del weight_copy[chosen_idx]
    return samples
if __name__ == "__main__":
    # Small demo: four candidates with strongly skewed weights.
    population, weights = ['a', 'b', 'c', 'd'], [10, 5, 1, 1]
    # Ten draws with replacement.
    print("采样结果(有放回):", alias_choice(population, weights, k=10))
    # Three draws without replacement.
    print("采样结果(无放回):", alias_sample(population, weights, k=3))