Static values explanation
Greetings,
I am trying to implement my own custom dataset in the unified format and eventually train a PipelineAgent using my DST implementation, a DQN policy and a TemplateNLG. I saw that the current implementation of the DQN only supports multiwoz21. As such, I started to make a few modifications to suit my needs and found out that I need to implement a custom Vector as well. Looking at VectorBinary, I see this:
def get_state_dim(self):
    self.belief_state_dim = 0
    for domain in self.ontology['state']:
        for slot in self.ontology['state'][domain]:
            self.belief_state_dim += 1
    self.state_dim = self.da_opp_dim + self.da_dim + self.belief_state_dim + \
        len(self.db_domains) + 6 * len(self.db_domains) + 1
I do not understand the last 3 terms of the sum. What is the intuition behind 6 * len(self.db_domains) + 1?
Thank you for your time!
Hello!
https://github.com/ConvLab/ConvLab-3/blob/master/convlab/policy/vector/vector_binary.py#L61
In this line you can see how the vectorized state is assembled. The "+ 1" you mention is a binary flag indicating whether the dialogue has terminated or not.
The "6 * len(self.db_domains)" is a feature vector that encodes, for every database domain, how many entities were found: each domain gets a 6-dimensional vector that encodes the number of database results.
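To make that concrete, here is a minimal sketch in plain Python. It is not the actual ConvLab-3 code: the helper names are hypothetical, and the bucketing of result counts (0, 1, 2, 3, 4, and 5 or more) is an assumption chosen only to show how a count can become a 6-dimensional vector per domain.

```python
def encode_db_count(num_results, size=6):
    """One-hot encode a result count into `size` buckets.

    Assumed bucketing (illustrative only): 0, 1, 2, 3, 4, and 5-or-more.
    """
    vec = [0] * size
    bucket = max(0, min(num_results, size - 1))  # clamp into [0, size - 1]
    vec[bucket] = 1
    return vec


def encode_db_features(db_counts, db_domains):
    """Concatenate one 6-dim vector per database domain."""
    features = []
    for domain in db_domains:
        features.extend(encode_db_count(db_counts.get(domain, 0)))
    return features


def state_dim(da_opp_dim, da_dim, belief_state_dim, num_db_domains):
    """Same sum as in get_state_dim, with the terms spelled out."""
    db_result_dim = 6 * num_db_domains  # 6-dim count encoding per domain
    terminated_dim = 1                  # binary "dialogue terminated" flag
    return (da_opp_dim + da_dim + belief_state_dim
            + num_db_domains + db_result_dim + terminated_dim)


# Hypothetical example with three database domains
db_domains = ["hotel", "restaurant", "train"]
db_counts = {"hotel": 3, "restaurant": 0, "train": 12}
db_features = encode_db_features(db_counts, db_domains)
print(len(db_features))  # 6 * 3 = 18
```

For a custom Vector the same pattern should carry over: whatever encoding you pick for the database results, its length per domain has to match the factor used in get_state_dim.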
I hope that helps!
Thanks @ChrisGeishauser. Yes, the "6 * len(self.db_domains)" indicates how many results are found in the database. By the way, I'm unsure whether DQN works well. It's a community implementation (https://github.com/thu-coai/ConvLab-2/pull/113) [discussion]. If DQN does not work well, you can try policy-based models.