Reinforcement-learning-with-tensorflow
min_prob always returns 0
https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/blob/master/contents/5.2_Prioritized_Replay_DQN/RL_brain.py#L114
min_prob = np.min(self.tree.tree[-self.tree.capacity:]) / self.tree.total_p # for later calculate ISweight
Since the SumTree still contains a large number of zeros at the beginning, np.min returns 0, which makes
ISWeights[i, 0] = np.power(prob/min_prob, -self.beta)
produce a wrong result (prob/min_prob becomes a division by zero).
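A minimal sketch of the failure mode, assuming a capacity of 8 with only three transitions stored so far (the numbers are made up for illustration):

```python
import numpy as np

capacity = 8
tree = np.zeros(2 * capacity - 1)                    # SumTree array; every leaf starts at 0
tree[capacity - 1:capacity + 2] = [1.2, 0.8, 0.5]    # only three transitions stored so far
total_p = tree[-capacity:].sum()

min_prob = np.min(tree[-capacity:]) / total_p        # the empty leaves are still 0
print(min_prob)                                      # 0.0 -> prob / min_prob divides by zero
```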
Some people work around this by changing min_prob to 0.0001 whenever it is 0:
https://blog.csdn.net/gsww404/article/details/103673852
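That workaround amounts to roughly the following inside sample() (a sketch based on the description above, not code taken from the linked post):

```python
min_prob = np.min(self.tree.tree[-self.tree.capacity:]) / self.tree.total_p
if min_prob == 0:
    min_prob = 0.0001   # arbitrary floor just to avoid dividing by zero
```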
Even so, I don't think that is a very good approach.
Moreover, add has to compute the max p over the whole tree and sample has to compute the min p over the whole tree, which wastes a lot of resources. I still maintain that p should be passed in from outside, which gives a more accurate probability and reduces the update delay.
Here is the code after my clean-up:
```python
import random

import numpy as np


class PrioritizedMemory(object):  # stored as ( s, a, r, s_ ) in SumTree
    e = 0.01
    a = 0.6
    beta = 0.4
    beta_increment_per_sampling = 0.001

    def __init__(self, capacity):
        self.tree = SumTree(capacity)
        self.capacity = capacity

    def _get_priority(self, error):
        return (np.abs(error) + self.e) ** self.a

    def add(self, error, sample):
        # the TD error is supplied by the caller, so no max-p scan over the tree is needed
        p = self._get_priority(error)
        self.tree.add(p, sample)

    def sample(self, n):
        batch = []
        idxs = []
        segment = self.tree.total() / n
        priorities = []

        self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])

        for i in range(n):
            a = segment * i
            b = segment * (i + 1)
            s = random.uniform(a, b)
            (idx, p, data) = self.tree.get(s)
            priorities.append(p)
            batch.append(data)
            idxs.append(idx)

        sampling_probabilities = np.array(priorities) / self.tree.total()
        is_weight = np.power(self.tree.n_entries * sampling_probabilities, -self.beta)
        is_weight /= is_weight.max()  # normalise by the max weight instead of dividing by min_prob

        return idxs, batch, is_weight

    def batch_update(self, tree_idx, abs_errors):
        for ti, e in zip(tree_idx, abs_errors):
            p = self._get_priority(e)
            self.tree.update(ti, p)


class SumTree:
    data_pointer = 0

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        self.data = np.zeros(capacity, dtype=object)
        self.n_entries = 0

    def update(self, tree_idx, p):
        change = p - self.tree[tree_idx]
        self.tree[tree_idx] = p
        # then propagate the change through tree
        while tree_idx != 0:    # this method is faster than the recursive loop in the reference code
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change

    def total(self):
        return self.tree[0]

    def add(self, p, data):
        tree_idx = self.data_pointer + self.capacity - 1
        self.data[self.data_pointer] = data  # update data_frame
        self.update(tree_idx, p)             # update tree_frame

        self.data_pointer += 1
        if self.data_pointer >= self.capacity:  # replace when exceed the capacity
            self.data_pointer = 0

        if self.n_entries < self.capacity:
            self.n_entries += 1

    def get(self, v):
        """
        Tree structure and array storage:
        Tree index:
             0         -> storing priority sum
            / \
          1     2
         / \   / \
        3   4 5   6    -> storing priority for transitions
        Array type for storing:
        [0,1,2,3,4,5,6]
        """
        parent_idx = 0
        while True:     # the while loop is faster than the method in the reference code
            cl_idx = 2 * parent_idx + 1         # this leaf's left and right kids
            cr_idx = cl_idx + 1
            if cl_idx >= len(self.tree):        # reach bottom, end search
                leaf_idx = parent_idx
                break
            else:       # downward search, always search for a higher priority node
                if v <= self.tree[cl_idx]:
                    parent_idx = cl_idx
                else:
                    v -= self.tree[cl_idx]
                    parent_idx = cr_idx

        data_idx = leaf_idx - self.capacity + 1
        return leaf_idx, self.tree[leaf_idx], self.data[data_idx]
```
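For reference, a minimal usage sketch with dummy transitions (the transition contents and TD errors below are placeholder values; in a real agent they come from the environment and the network). It shows that the priority is always derived from an error supplied by the caller, so the memory never scans the tree for a max or min p:

```python
import numpy as np

memory = PrioritizedMemory(capacity=8)

for step in range(8):
    transition = (np.zeros(4), 0, 1.0, np.zeros(4))   # (s, a, r, s_)
    td_error = np.random.rand()                       # stands in for |q_target - q_eval|
    memory.add(td_error, transition)                  # priority comes from the caller's error

idxs, batch, is_weight = memory.sample(4)             # IS weights, normalised by their max
new_abs_errors = np.random.rand(4)                    # updated |TD errors| after a learning step
memory.batch_update(idxs, new_abs_errors)
```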