unitxt
Fusion does not assign a value to the "group" field of every instance, resulting in errors in metric computation
The Fusion classes were supposed to add a field named "group" to every instance of the fused streams, stating the name of the instance's origin. In turn, at metric computation time, the metric pipeline splits the instances by group (done here: https://github.com/IBM/unitxt/blob/4bdefa7fdf3a801fd214e07c9c89ec384fbe66b4/src/unitxt/metric_utils.py#L110) and calculates the metrics for each group separately, based on the group metric. An instance without a "group" value therefore breaks this step.
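To make the dependency concrete, here is a minimal, self-contained sketch of per-group metric computation. It assumes instances are plain dicts carrying a "group" key and uses exact-match accuracy as a stand-in for the group metric. `split_by_group`, `accuracy`, and `per_group_scores` are hypothetical helpers, not unitxt APIs; the real logic lives in `metric_utils.py` linked above.

```python
from collections import defaultdict
from typing import Dict, Iterable, List


def split_by_group(instances: Iterable[dict]) -> Dict[str, List[dict]]:
    """Bucket instances by their "group" field (hypothetical helper).

    Raises KeyError for any instance that lacks "group", which mirrors the
    failure described in this issue when Fusion does not set the field.
    """
    groups: Dict[str, List[dict]] = defaultdict(list)
    for instance in instances:
        groups[instance["group"]].append(instance)
    return dict(groups)


def accuracy(instances: List[dict]) -> float:
    """Stand-in group metric: fraction of exact prediction/reference matches."""
    correct = sum(1 for inst in instances if inst["prediction"] == inst["reference"])
    return correct / len(instances)


def per_group_scores(instances: Iterable[dict]) -> Dict[str, float]:
    """Compute the stand-in metric separately for each origin group."""
    return {name: accuracy(group) for name, group in split_by_group(instances).items()}


if __name__ == "__main__":
    data = [
        {"group": "task_a", "prediction": "yes", "reference": "yes"},
        {"group": "task_a", "prediction": "no", "reference": "yes"},
        {"group": "task_b", "prediction": "1", "reference": "1"},
    ]
    print(per_group_scores(data))  # {'task_a': 0.5, 'task_b': 1.0}
```

Dropping the "group" key from any instance in `data` makes `split_by_group` raise `KeyError`, which is the kind of error this issue describes.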
This requires two changes:
- Changing the `origins` field in the Fusion classes to a dictionary mapping origin names to origin streams.
- Adding the name of the origin to every instance, like here:
```python
from typing import Generator, Optional

# BaseFusion is the existing base class in unitxt's fusion module; here its
# `origins` field is assumed to be a dict of origin name -> SourceOperator.


class FixedFusion(BaseFusion):
    """FixedFusion operator that combines multiple streams into one based on a fixed number of examples per task.

    Args:
        origins: Dict mapping origin names to SourceOperator objects.
        max_instances_per_origin: Maximum number of instances taken from each origin. If None, all instances are returned.
        splits: List of splits to include. If None, all splits are included.
    """

    max_instances_per_origin: Optional[int] = None

    def fusion_generator(self, split) -> Generator:
        for origin_name, origin in self.origins.items():
            multi_stream = origin()
            if split not in multi_stream:
                continue
            iterator = iter(multi_stream[split])
            if self.max_instances_per_origin is not None:
                for _ in range(self.max_instances_per_origin):
                    try:
                        instance = next(iterator)
                    except StopIteration:
                        break
                    # Tag the instance with its origin so metrics can be split by group.
                    instance["group"] = origin_name
                    yield instance
            else:
                for instance in iterator:
                    instance["group"] = origin_name
                    yield instance
```
Then, of course, we need to verify with tests that it all works.
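A minimal sketch of such a test, written against a plain-Python stand-in rather than unitxt itself. `fixed_fusion` is a hypothetical function that mirrors the proposed `fusion_generator` logic, and each origin is modeled as a zero-argument callable returning a dict of split name to instance list (an assumption standing in for a `MultiStream`).

```python
from itertools import islice
from typing import Callable, Dict, Iterator, List, Optional

# Hypothetical stand-in for a SourceOperator: calling it returns {split: instances}.
Origin = Callable[[], Dict[str, List[dict]]]


def fixed_fusion(
    origins: Dict[str, Origin], split: str, max_instances_per_origin: Optional[int] = None
) -> Iterator[dict]:
    """Mirror of the proposed FixedFusion.fusion_generator (hypothetical)."""
    for origin_name, origin in origins.items():
        multi_stream = origin()
        if split not in multi_stream:
            continue
        # islice(..., None) takes every instance; otherwise at most the given count.
        for instance in islice(multi_stream[split], max_instances_per_origin):
            instance["group"] = origin_name
            yield instance


def test_every_instance_gets_its_origin_group() -> None:
    origins = {
        "task_a": lambda: {"test": [{"x": 1}, {"x": 2}, {"x": 3}]},
        "task_b": lambda: {"test": [{"x": 4}], "train": [{"x": 5}]},
        "task_c": lambda: {"train": [{"x": 6}]},  # no "test" split, so it is skipped
    }
    fused = list(fixed_fusion(origins, "test", max_instances_per_origin=2))
    assert [inst["group"] for inst in fused] == ["task_a", "task_a", "task_b"]
    assert [inst["x"] for inst in fused] == [1, 2, 4]

    # Without a cap, every instance of the split is kept and labeled.
    fused_all = list(fixed_fusion(origins, "test"))
    assert all("group" in inst for inst in fused_all)
    assert len(fused_all) == 4


if __name__ == "__main__":
    test_every_instance_gets_its_origin_group()
    print("ok")
```

Here `itertools.islice` replaces the `next`/`StopIteration` loop with the same behavior: it stops at the cap or when the stream is exhausted, whichever comes first. A real test would instead build `FixedFusion` with actual unitxt source operators and check the "group" field on its output.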