vanna
Add validation to `vn.train()` inputs
This PR adds another input argument check to `vn.train()`. If more than one input is used (`question` and `sql`, `documentation`, `ddl`, `plan`), a validation error is raised to notify the user.
The Problem
- If more than one argument is sent to `vn.train()`, only one item will be sent to the database for later referencing.
- The user won't know which one was stored and, since no warning is raised, won't notice their mistake.
Context
- `vn.train()` accepts multiple kwargs, one for each training item, all of which default to `None` (`question` and `sql`, `documentation`, `ddl`, `plan`).
- If exactly one of `sql` and `question` is `None` (that is, they are neither both `None` nor both set), a `ValidationError` exception is raised.
- If (`sql`, `question`) and any other of `documentation`, `ddl`, `plan` are given, only one will be added, in this order: `documentation`, `sql` & `question`, `ddl`, `plan` (which allows for all or some).
vn = VannaDefault()
vn.train(
    ddl="[EmployeeId] INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, "
        "[LastName] NVARCHAR(20) NOT NULL,",
    documentation="The 'employees' table contains the FirstName and LastName of employees",
)
# Only documentation would be added; the ddl is silently dropped
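The precedence described in this section can be modeled with a small stand-alone sketch. It does not call Vanna. `stored_item` is a hypothetical helper that only mirrors the order listed here (documentation, then sql and question, then ddl, then plan), so treat it as an illustration rather than the library's code.

```python
from typing import Optional


def stored_item(
    question: Optional[str] = None,
    sql: Optional[str] = None,
    ddl: Optional[str] = None,
    documentation: Optional[str] = None,
    plan: Optional[object] = None,
) -> Optional[str]:
    """Return the name of the single input that would be stored.

    Hypothetical model of the precedence described in this PR text;
    not Vanna's actual implementation.
    """
    if documentation is not None:
        return "documentation"
    if sql is not None:
        return "question and sql"
    if ddl is not None:
        return "ddl"
    if plan is not None:
        return "plan"
    return None


# Both ddl and documentation are passed, but only one of them wins.
print(stored_item(ddl="CREATE TABLE t (id INTEGER)", documentation="Table t holds ids"))
# prints: documentation
```

Every later input in the chain is ignored without any message, which is the silent data loss this PR targets.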
The Solution
- Refactored the existing validation into its own method, and added another check for when more than one training input is not `None`.
class VannaBase:
    ...
    @staticmethod
    def _validate_train_args(
        question: Optional[str] = None,
        sql: Optional[str] = None,
        ddl: Optional[str] = None,
        documentation: Optional[str] = None,
        plan: Optional[TrainingPlan] = None,
    ) -> None:
        # Implementation here ...
    ...
    def train(
        self,
        question: str = None,
        sql: str = None,
        ddl: str = None,
        documentation: str = None,
        plan: TrainingPlan = None,
    ) -> str:
        ...
        self._validate_train_args(question, sql, ddl, documentation, plan)
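The body of the validator is elided above. As a rough illustration of the two checks this description names, here is a minimal stand-alone sketch. It is an assumption, not the PR's actual code: `validate_train_args` is a free function rather than the static method, `ValidationError` is a local stand-in for the library's exception, `plan` is typed loosely, and the error messages are invented.

```python
from typing import Optional


class ValidationError(Exception):
    """Local stand-in for the library's validation exception (assumption)."""


def validate_train_args(
    question: Optional[str] = None,
    sql: Optional[str] = None,
    ddl: Optional[str] = None,
    documentation: Optional[str] = None,
    plan: Optional[object] = None,
) -> None:
    # Existing check: question and sql must be given together or not at all.
    if (question is None) != (sql is None):
        raise ValidationError("`question` and `sql` must be provided together.")

    # New check: treat (question, sql) as one input and count what was given.
    provided = [
        name
        for name, given in (
            ("question and sql", sql is not None),
            ("documentation", documentation is not None),
            ("ddl", ddl is not None),
            ("plan", plan is not None),
        )
        if given
    ]
    if len(provided) > 1:
        raise ValidationError(
            "Only one training input may be passed per call, got: "
            + ", ".join(provided)
        )


# Usage: the call from the Context example is now rejected instead of
# silently keeping only the documentation.
try:
    validate_train_args(ddl="CREATE TABLE t (id INTEGER)", documentation="Table t")
except ValidationError as exc:
    print(exc)
```

A call with no inputs at all passes this sketch unchanged; whether that case should also be rejected is left open here because the description does not address it.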
Tests
- Followed steps as per `CONTRIBUTING.md`. Tests pass.