Should `sentence_segmenter` in metrics be adapted to the language in the `adapt` function?
- [ ] I have checked the documentation and related resources and couldn't resolve my bug.
**Describe the bug**
It's good that almost all metrics in ragas can be adapted to other languages, but I found that for metrics that use a `sentence_segmenter`, the segmenter is created right after initialization (in `__post_init__`) from the prompt's initial language.
I think the `sentence_segmenter` should also be adapted in the `adapt` method; otherwise the segmenter still targets the initial language even after the prompts have been adapted.
**Code to Reproduce**

```python
@dataclass
class Faithfulness(MetricWithLLM):
    name: str = "faithfulness"  # type: ignore
    evaluation_mode: EvaluationMode = EvaluationMode.qac  # type: ignore
    nli_statements_message: Prompt = field(
        default_factory=lambda: NLI_STATEMENTS_MESSAGE
    )
    statement_prompt: Prompt = field(default_factory=lambda: LONG_FORM_ANSWER_PROMPT)
    sentence_segmenter: t.Optional[HasSegmentMethod] = None
    max_retries: int = 1
    _reproducibility: int = 1

    @property
    def reproducibility(self):
        return self._reproducibility

    @reproducibility.setter
    def reproducibility(self, value):
        if value < 1:
            logger.warning("reproducibility cannot be less than 1, setting to 1")
            value = 1
        elif value % 2 == 0:
            logger.warning(
                "reproducibility level cannot be set to even number, setting to odd"
            )
            value += 1
        self._reproducibility = value

    def __post_init__(self):
        if self.sentence_segmenter is None:
            # NOTE: the segmenter is chosen once here, from the *initial*
            # prompt language; `adapt` below never touches it again
            language = self.nli_statements_message.language
            self.sentence_segmenter = get_segmenter(language=language, clean=False)

    def _create_nli_prompt(self, row: t.Dict, statements: t.List[str]) -> PromptValue:
        assert self.llm is not None, "llm must be set to compute score"

        contexts = row["contexts"]
        # check if the statements are supported in the contexts
        contexts_str: str = "\n".join(contexts)
        statements_str: str = json.dumps(statements)
        prompt_value = self.nli_statements_message.format(
            context=contexts_str, statements=statements_str
        )
        return prompt_value

    def _create_statements_prompt(self, row: t.Dict) -> PromptValue:
        assert self.sentence_segmenter is not None, "sentence_segmenter is not set"

        text, question = row["answer"], row["question"]
        sentences = self.sentence_segmenter.segment(text)
        sentences = [
            sentence for sentence in sentences if sentence.strip().endswith(".")
        ]
        sentences = "\n".join([f"{i}:{x}" for i, x in enumerate(sentences)])
        prompt_value = self.statement_prompt.format(
            question=question, answer=text, sentences=sentences
        )
        return prompt_value

    def _compute_score(self, answers: StatementFaithfulnessAnswers):
        # check the verdicts and compute the score
        faithful_statements = sum(
            1 if answer.verdict else 0 for answer in answers.__root__
        )
        num_statements = len(answers.__root__)
        if num_statements:
            score = faithful_statements / num_statements
        else:
            logger.warning("No statements were generated from the answer.")
            score = np.nan

        return score

    async def _ascore(
        self: t.Self, row: t.Dict, callbacks: Callbacks, is_async: bool
    ) -> float:
        """
        returns the NLI score for each (q, c, a) pair
        """
        assert self.llm is not None, "LLM is not set"

        p_value = self._create_statements_prompt(row)
        statements = await self.llm.generate(
            p_value,
            callbacks=callbacks,
            is_async=is_async,
        )
        statements = await _statements_output_parser.aparse(
            statements.generations[0][0].text, p_value, self.llm, self.max_retries
        )

        if statements is None:
            return np.nan

        statements = [item["simpler_statements"] for item in statements.dicts()]
        statements = [item for sublist in statements for item in sublist]

        assert isinstance(statements, t.List), "statements must be a list"

        p_value = self._create_nli_prompt(row, statements)
        nli_result = await self.llm.generate(
            p_value,
            callbacks=callbacks,
            is_async=is_async,
            n=self._reproducibility,
        )

        nli_result_text = [
            nli_result.generations[0][i].text for i in range(self._reproducibility)
        ]
        faithfulness_list = [
            await _faithfulness_output_parser.aparse(
                text, p_value, self.llm, self.max_retries
            )
            for text in nli_result_text
        ]

        faithfulness_list = [
            faith.dicts() for faith in faithfulness_list if faith is not None
        ]

        if faithfulness_list:
            faithfulness_list = ensembler.from_discrete(
                faithfulness_list,
                "verdict",
            )
            faithfulness_list = StatementFaithfulnessAnswers.parse_obj(
                faithfulness_list
            )
        else:
            return np.nan

        return self._compute_score(faithfulness_list)

    def adapt(self, language: str, cache_dir: t.Optional[str] = None) -> None:
        assert self.llm is not None, "LLM is not set"

        logger.info(f"Adapting Faithfulness metric to {language}")
        self.nli_statements_message = self.nli_statements_message.adapt(
            language, self.llm, cache_dir
        )
        self.statement_prompt = self.statement_prompt.adapt(
            language, self.llm, cache_dir
        )
        # NOTE: self.sentence_segmenter is never re-created here

    def save(self, cache_dir: t.Optional[str] = None) -> None:
        self.nli_statements_message.save(cache_dir)
```
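For example, a minimal sketch of the symptom (hypothetical usage; it assumes an LLM has already been attached to the metric, since `adapt` asserts `self.llm is not None`):

```python
# Hypothetical sketch of the symptom, assuming faithfulness.llm has been set
from ragas.metrics import faithfulness

faithfulness.adapt(language="hindi")

# The prompts are adapted to Hindi...
print(faithfulness.nli_statements_message.language)  # expected: "hindi"
# ...but the segmenter created in __post_init__ still targets the original
# language, because adapt() never re-creates self.sentence_segmenter.
```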
@jmgu0212 Hello there! I'm here to assist you with any bugs, questions, or contribution inquiries.
To ensure the `sentence_segmenter` is adapted to the correct language, you can modify the `adapt` method of the `Faithfulness` class to rebuild the segmenter as well. Here is how you can do it:
- Add the `sentence_segmenter` adaptation logic to the `adapt` method.
- Ensure the `sentence_segmenter` is updated to the specified language.
Here is the updated `adapt` method for the `Faithfulness` class (the rest of the class is unchanged):
```python
class Faithfulness(MetricWithLLM):
    # ... fields, __post_init__, and scoring methods unchanged ...

    def adapt(self, language: str, cache_dir: t.Optional[str] = None) -> None:
        assert self.llm is not None, "LLM is not set"

        logger.info(f"Adapting Faithfulness metric to {language}")
        self.nli_statements_message = self.nli_statements_message.adapt(
            language, self.llm, cache_dir
        )
        self.statement_prompt = self.statement_prompt.adapt(
            language, self.llm, cache_dir
        )
        # rebuild the segmenter so it matches the newly adapted language
        if self.sentence_segmenter is not None:
            self.sentence_segmenter = get_segmenter(language=language, clean=False)

    def save(self, cache_dir: t.Optional[str] = None) -> None:
        self.nli_statements_message.save(cache_dir)
```
In this updated code, the `adapt` method now rebuilds the `sentence_segmenter` for the specified language, so the segmenter stays in sync with the prompts whenever `adapt` is called [1].
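With this change, a quick sanity check (again hypothetical usage, assuming the metric's LLM is already set) would look like:

```python
# Hypothetical check after the patch, assuming faithfulness.llm is set;
# "./prompt_cache" is just an illustrative cache directory
from ragas.metrics import faithfulness

faithfulness.adapt(language="hindi", cache_dir="./prompt_cache")
# Both the prompts and the sentence segmenter now target Hindi, so
# _create_statements_prompt segments Hindi answers with the right rules.
```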
Keep track of #890 for updates.
This has been fixed with v0.2 - I know, finally 😅 🎉
Do check out the docs here: https://docs.ragas.io/en/stable/howtos/customizations/metrics/_metrics_language_adaptation/ and the API reference here: https://docs.ragas.io/en/stable/references/prompt/#ragas.prompt.PromptMixin
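For reference, a minimal sketch of the v0.2 flow described in those docs (the `adapt_prompts`/`set_prompts` names come from the linked `PromptMixin` reference; double-check the exact signatures there):

```python
# Sketch of v0.2 language adaptation via PromptMixin (verify against the linked docs)
from langchain_openai import ChatOpenAI
from ragas.llms import LangchainLLMWrapper
from ragas.metrics import Faithfulness

llm = LangchainLLMWrapper(ChatOpenAI(model="gpt-4o"))
scorer = Faithfulness(llm=llm)

# adapt_prompts is async: translate the metric's prompts, then set them back
adapted_prompts = await scorer.adapt_prompts(language="hindi", llm=llm)
scorer.set_prompts(**adapted_prompts)
```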
If you're migrating from v0.1, check out the migration docs here: https://docs.ragas.io/en/stable/howtos/migrations/migrate_from_v01_to_v02
Could you check it out and verify? If not, feel free to comment here and I'll help you out. Really sorry again that it took this long.