fishtest icon indicating copy to clipboard operation
fishtest copied to clipboard

Type annotations in Fishtest

Open vdbergh opened this issue 1 year ago • 6 comments

To stay up to date I adapted vtjson to work well with type annotations. This would be the new runs schema:


import copy
import math
from datetime import datetime, timezone
from typing import Annotated, Literal, NotRequired, TypedDict

from bson.objectid import ObjectId

from vtjson import (
    at_most_one_of,
    div,
    fields,
    ge,
    glob,
    gt,
    ifthen,
    intersect,
    ip_address,
    keys,
    lax,
    one_of,
    quote,
    regex,
    skip_first,
    url,
)

# --- Reusable schema aliases -------------------------------------------------
# Each Annotated[...] pairs a base Python type (used by static type checkers)
# with vtjson validators (applied at run time). skip_first tells vtjson not to
# re-check the first (base type) argument.
username = Annotated[str, regex(r"[!-~][ -~]{0,30}[!-~]", name="username"), skip_first]
net_name = Annotated[str, regex("nn-[a-f0-9]{12}.nnue", name="net_name"), skip_first]
# time control, e.g. "10+0.1" or "40/5+0.05"
tc = Annotated[
    str, regex(r"([1-9]\d*/)?\d+(\.\d+)?(\+\d+(\.\d+)?)?", name="tc"), skip_first
]
# positive integer written as a string, no leading zeros
str_int = Annotated[str, regex(r"[1-9]\d*", name="str_int"), skip_first]
# 40-character lowercase hex digest (e.g. a git commit hash)
sha = Annotated[str, regex(r"[a-f0-9]{40}", name="sha"), skip_first]
country_code = Annotated[str, regex(r"[A-Z][A-Z]", name="country_code"), skip_first]
# string form of a MongoDB ObjectId, validated by ObjectId.is_valid directly
run_id = Annotated[str, ObjectId.is_valid]
uuid = Annotated[
    str,
    regex(r"[0-9a-zA-Z]{2,}(-[a-f0-9]{4}){3}-[a-f0-9]{12}", name="uuid"),
    skip_first,
]
epd_file = Annotated[str, glob("*.epd", name="epd_file"), skip_first]
pgn_file = Annotated[str, glob("*.pgn", name="pgn_file"), skip_first]
# integer divisible by 2
even = Annotated[int, div(2, name="even"), skip_first]
# timezone-aware datetime whose tzinfo field is exactly timezone.utc
datetime_utc = Annotated[datetime, fields({"tzinfo": timezone.utc})]

# numeric aliases: unsigned (>= 0) and strictly positive (> 0) int / float
uint = Annotated[int, ge(0)]
suint = Annotated[int, gt(0)]
ufloat = Annotated[float, ge(0)]
sufloat = Annotated[float, gt(0)]


class results_type(TypedDict):
    """Aggregated game results of a run or of a single task."""

    wins: uint
    losses: uint
    draws: uint
    crashes: uint
    time_losses: uint
    # Exactly five non-negative counts: frequencies of the five possible
    # outcomes of a game pair (0, 0.5, 1, 1.5, 2 points) — see valid_results.
    pentanomial: Annotated[list[int], [uint, uint, uint, uint, uint], skip_first]


def valid_results(R: results_type) -> bool:
    """Check that the trinomial counts (wins/draws/losses) in *R* are
    arithmetically consistent with its pentanomial game-pair counts."""
    wins, losses, draws = R["wins"], R["losses"], R["draws"]
    penta = R["pentanomial"]
    # every pentanomial entry represents one pair of games
    games_match = wins + draws + losses == 2 * sum(penta)
    # net score of the pairs must equal the win/loss difference
    score_match = wins - losses == 2 * penta[4] + penta[3] - penta[1] - 2 * penta[0]
    # the number of draws is bracketed by the pair outcomes that can contain them
    draws_bounded = penta[3] + 2 * penta[2] + penta[1] >= draws >= penta[3] + penta[1]
    return games_match and score_match and draws_bounded


# results_type whose counts must additionally satisfy the cross-field
# consistency check implemented by valid_results.
results_schema = Annotated[
    results_type,
    valid_results,
]


class worker_info_schema(TypedDict):
    """Description of the worker machine that runs (or ran) a task."""

    uname: str
    # exactly two free-form platform strings
    architecture: Annotated[list[str], [str, str], skip_first]
    # cores offered by this worker (summed into run["cores"] for active tasks)
    concurrency: suint
    max_memory: uint
    min_threads: suint
    username: str
    version: uint
    # [major, minor, patch]
    python_version: Annotated[list[int], [uint, uint, uint], skip_first]
    gcc_version: Annotated[list[int], [uint, uint, uint], skip_first]
    compiler: Literal["clang++", "g++"]
    # worker instance identifier in uuid form
    unique_key: uuid
    modified: bool
    ARCH: str
    nps: ufloat
    near_github_api_limit: bool
    remote_addr: Annotated[str, ip_address]
    # two-letter country code; presumably "?" when it cannot be determined
    country_code: country_code | Literal["?"]


class overshoot_type(TypedDict):
    """Running statistics used for the SPRT overshoot correction."""

    last_update: uint
    skipped_updates: uint
    # NOTE(review): ref/m/sq look like running reference, mean and
    # sum-of-squares statistics for the two SPRT bounds — confirm against
    # the fishtest SPRT implementation
    ref0: float
    m0: float
    sq0: ufloat
    ref1: float
    m1: float
    sq1: ufloat


class sprt_type(TypedDict):
    """State of the sequential probability ratio test of a run."""

    # alpha/beta are fixed: the schema constant 0.05 means the field must
    # literally equal 0.05 (skip_first skips the redundant float check)
    alpha: Annotated[float, 0.05, skip_first]
    beta: Annotated[float, 0.05, skip_first]
    elo0: float
    elo1: float
    elo_model: Literal["normalized"]
    # "" while the test is still running
    state: Literal["", "accepted", "rejected"]
    llr: float
    batch_size: suint
    # classical SPRT bounds log(beta/(1-alpha)) and log((1-beta)/alpha),
    # which equal -/+ log(19) for alpha = beta = 0.05
    lower_bound: Annotated[float, -math.log(19), skip_first]
    upper_bound: Annotated[float, math.log(19), skip_first]
    lost_samples: NotRequired[uint]
    illegal_update: NotRequired[uint]
    # exactly one of overshoot/lost_samples is present (see sprt_schema)
    overshoot: NotRequired[overshoot_type]


# sprt_type with the extra constraint that exactly one of the keys
# "overshoot", "lost_samples" is present.
sprt_schema = Annotated[
    sprt_type,
    one_of("overshoot", "lost_samples"),
]


class param_schema(TypedDict):
    """A single SPSA-tuned parameter with its schedule constants."""

    name: str
    start: float
    min: float
    max: float
    # NOTE(review): c/a appear to be the SPSA perturbation and step sizes,
    # with *_end their final scheduled values — confirm against the
    # fishtest SPSA implementation
    c_end: sufloat
    r_end: ufloat
    c: sufloat
    a_end: ufloat
    a: ufloat
    # current value of the parameter
    theta: float


class param_history_schema(TypedDict):
    """One recorded point of a parameter's SPSA trajectory."""

    theta: float
    R: ufloat
    c: ufloat


class spsa_schema(TypedDict):
    """SPSA tuning configuration and state of a run."""

    A: ufloat
    alpha: ufloat
    gamma: ufloat
    raw_params: str
    iter: uint
    num_iter: uint
    params: list[param_schema]
    # presumably one inner list (one entry per parameter) per recorded
    # iteration — TODO confirm
    param_history: NotRequired[list[list[param_history_schema]]]


class args_type(TypedDict):
    """User-supplied parameters of a test run."""

    base_tag: str
    new_tag: str
    base_nets: list[net_name]
    new_nets: list[net_name]
    # total games must be even (games come in pairs; see results_type)
    num_games: Annotated[uint, even]
    tc: tc
    new_tc: tc
    # opening book: either an .epd or a .pgn file
    book: epd_file | pgn_file
    book_depth: str_int
    threads: suint
    # resolved commit hashes for the tested branches
    resolved_base: sha
    resolved_new: sha
    master_sha: sha
    official_master_sha: sha
    msg_base: str
    msg_new: str
    base_options: str
    new_options: str
    info: str
    base_signature: str_int
    new_signature: str_int
    username: username
    tests_repo: Annotated[str, url, skip_first]
    auto_purge: bool
    throughput: ufloat
    itp: ufloat
    priority: float
    adjudication: bool
    # at most one of sprt/spsa may be present (enforced by args_schema)
    sprt: NotRequired[sprt_schema]
    spsa: NotRequired[spsa_schema]


# args_type with the extra constraint that the keys "sprt" and "spsa"
# are mutually exclusive.
args_schema = Annotated[
    args_type,
    at_most_one_of("sprt", "spsa"),
]


class task_type(TypedDict):
    """A slice of a run assigned to a single worker."""

    # games allotted to this task; must be even
    num_games: Annotated[uint, even]
    active: bool
    last_updated: datetime_utc
    # NOTE(review): looks like the index of this task's first game within
    # the run — confirm
    start: uint
    residual: float
    residual_color: NotRequired[str]
    # only ever present with the literal value True
    bad: NotRequired[Literal[True]]
    stats: results_schema
    worker_info: worker_info_schema


# Canonical all-zero results: the accumulator seed in final_results_must_match
# and the mandatory stats of a task marked bad.
zero_results: results_type = {
    "wins": 0,
    "draws": 0,
    "losses": 0,
    "crashes": 0,
    "time_losses": 0,
    "pentanomial": 5 * [0],
}

# A task carrying the "bad" key must be inactive with zeroed stats.
# quote() makes vtjson treat zero_results as a literal value, not a schema;
# lax() lets the partial schema ignore the keys it does not mention.
if_bad_then_zero_stats_and_not_active = ifthen(
    keys("bad"), lax({"active": False, "stats": quote(zero_results)})
)

# task_type plus the bad-task invariant above.
task_schema = Annotated[
    task_type,
    if_bad_then_zero_stats_and_not_active,
]


class bad_task_schema(TypedDict):
    """Archived copy of a task that was marked bad (run["bad_tasks"])."""

    num_games: Annotated[uint, even]
    # a bad task is never active
    active: Literal[False]
    last_updated: datetime_utc
    start: uint
    residual: float
    residual_color: str
    bad: Literal[True]
    # presumably the index of the task in the run's task list — confirm
    task_id: uint
    stats: results_schema
    worker_info: worker_info_schema


class results_info_schema(TypedDict):
    """Presentation data for a run's results."""

    # NOTE(review): style looks like a display/color hint — confirm
    style: str
    info: list[str]


class runs_type(TypedDict):
    """Top-level document describing a test run."""

    # NOTE(review): presumably assigned by MongoDB on insert, hence optional
    _id: NotRequired[ObjectId]
    version: uint
    start_time: datetime_utc
    last_updated: datetime_utc
    tc_base: ufloat
    base_same_as_master: bool
    # id of the run this one was rescheduled from, when applicable
    rescheduled_from: NotRequired[run_id]
    approved: bool
    # empty string while the run is unapproved (enforced by runs_schema)
    approver: username | Literal[""]
    finished: bool
    deleted: bool
    failed: bool
    is_green: bool
    is_yellow: bool
    # aggregated over active tasks; cross-checked by workers_must_match
    # and cores_must_match
    workers: uint
    cores: uint
    results: results_schema
    results_info: NotRequired[results_info_schema]
    args: args_schema
    tasks: list[task_schema]
    bad_tasks: NotRequired[list[bad_task_schema]]


def final_results_must_match(run: runs_type) -> bool:
    """Validate that run["results"] equals the sum of all task stats.

    Returns True on success; raises Exception when the totals disagree.
    """
    total = copy.deepcopy(zero_results)
    for task in run["tasks"]:
        stats = task["stats"]
        # mypy does not support variable keys for TypedDict,
        # so each counter is summed explicitly
        total["wins"] += stats["wins"]
        total["losses"] += stats["losses"]
        total["draws"] += stats["draws"]
        total["crashes"] += stats["crashes"]
        total["time_losses"] += stats["time_losses"]
        for idx, pair_count in enumerate(stats["pentanomial"]):
            total["pentanomial"][idx] += pair_count
    if total == run["results"]:
        return True
    raise Exception(
        f"The final results {run['results']} do not match the computed results {total}"
    )


def cores_must_match(run: runs_type) -> bool:
    """Validate that run["cores"] equals the summed concurrency of the
    active tasks. Returns True on success, raises Exception otherwise."""
    cores = sum(
        task["worker_info"]["concurrency"] for task in run["tasks"] if task["active"]
    )
    if cores != run["cores"]:
        raise Exception(
            f"Cores mismatch. Cores from tasks: {cores}. Cores from "
            f"run: {run['cores']}"
        )

    return True


def workers_must_match(run: runs_type) -> bool:
    """Validate that run["workers"] equals the number of active tasks.
    Returns True on success, raises Exception otherwise."""
    workers = sum(1 for task in run["tasks"] if task["active"])
    if workers != run["workers"]:
        raise Exception(
            f"Workers mismatch. Workers from tasks: {workers}. Workers from "
            f"run: {run['workers']}"
        )

    return True


# Cross-field invariants relating the aggregated run data to its task list;
# intersect() requires all three checks to pass.
valid_aggregated_data = intersect(
    final_results_must_match,
    cores_must_match,
    workers_must_match,
)

# Full runs schema: the runs_type shape plus inter-field consistency rules.
# Each lax() wrapper makes its partial schema ignore keys it does not mention.
runs_schema = Annotated[
    runs_type,
    # an approved run records its approver; an unapproved one has approver ""
    lax(ifthen({"approved": True}, {"approver": username}, {"approver": ""})),
    # a run cannot be green and yellow at the same time
    lax(ifthen({"is_green": True}, {"is_yellow": False})),
    lax(ifthen({"is_yellow": True}, {"is_green": False})),
    # failed or deleted runs are necessarily finished
    lax(ifthen({"failed": True}, {"finished": True})),
    lax(ifthen({"deleted": True}, {"finished": True})),
    # a finished run has no active workers/cores and no active tasks
    lax(ifthen({"finished": True}, {"workers": 0, "cores": 0})),
    lax(ifthen({"finished": True}, {"tasks": [{"active": False}, ...]})),
    valid_aggregated_data,
]

vdbergh avatar Nov 09 '24 21:11 vdbergh

I have now created comprehensive documentation for vtjson. See https://www.cantate.be/vtjson/ (canonical reference) or https://vtjson.readthedocs.io (if you don't mind some ads).

vdbergh avatar Dec 04 '24 06:12 vdbergh

Curiosities:

  • the example code uses snake case instead of camel case for the class naming
  • are the classes used only for validation, or are they also used in the code, replacing the dictionaries?

ppigazzini avatar Jan 06 '25 20:01 ppigazzini

The classes are used both for type checking and for validation. TypedDict is a standard python typing type for typed dictionaries.

class Foo(TypedDict):
   baz: int
   baz2: str

is functionally equivalent to

{"baz": int, "baz2": str} 

as far as vtjson is concerned. But the class Foo can also be used by static type checkers such as mypy for compile time validation

The point is that if we get some untyped JSON boo_untyped from an API then we can write

boo_typed = safe_cast(Foo, boo_untyped)

This will accomplish two things:

  • boo_untyped will be checked at run time to verify that it really corresponds to the schema Foo.
  • mypy will assign the type Foo to boo_typed and use it for further static type checking.

vdbergh avatar Jan 06 '25 20:01 vdbergh

the example code uses snake case instead of camel case for the class naming

I used lower case for the classes since they are really vtjson schemas in disguise and currently the schemas are written in lower case... I don't feel strongly about this.

vdbergh avatar Jan 06 '25 20:01 vdbergh

Very interesting! I can use Annotated in my fastapi projects to validate the data as well (too lazy to read any advanced use of Pydantic until now).

from typing import Annotated

from pydantic import (
    AnyUrl,
    BaseModel,
    EmailStr,
    Field,
    IPvAnyAddress,
    ValidationError,
    conint,
    conlist,
    constr,
)


class AddressModel(BaseModel):
    """US-style postal address validated with regex-constrained strings."""

    # street number followed by words/dots, e.g. "123 Main St."
    street: Annotated[constr(pattern=r"^\d+\s[A-Za-z\s\.]+$"), Field()]
    # letters and spaces only
    city: Annotated[constr(pattern=r"^[A-Za-z\s]+$"), Field()]
    # exactly five digits
    zipcode: Annotated[constr(pattern=r"^\d{5}$"), Field()]


class UserModel(BaseModel):
    """A user record with a nested address and network-related fields."""

    # at least two characters
    name: Annotated[constr(min_length=2), Field()]
    # strict int (no coercion from str/float) in the open interval (0, 120)
    age: Annotated[conint(strict=True, gt=0, lt=120), Field()]
    address: AddressModel
    email: EmailStr
    # optional list of URLs
    urls: list[AnyUrl] | None = None
    # exactly two IP addresses (v4 or v6)
    ips: Annotated[conlist(IPvAnyAddress, min_length=2, max_length=2), Field()]


class NestedDictModel(BaseModel):
    """Top-level payload: a user plus one extra required key."""

    user: UserModel
    other_key: str


def validate_data(data: dict) -> dict:
    """Validate *data* against NestedDictModel.

    Returns the validated payload as a plain dict. On failure, prints the
    validation errors as JSON and re-raises the ValidationError.
    """
    try:
        model = NestedDictModel(**data)
    except ValidationError as e:
        print(e.json())
        raise
    return model.model_dump()


# Example usage
# Data that satisfies NestedDictModel: a well-formed address, a valid email,
# parseable URLs and exactly two IP addresses (one v4, one v6).
good_data = {
    "user": {
        "name": "John Doe",
        "age": 40,
        "address": {"street": "123 Main St.", "city": "Anytown", "zipcode": "12345"},
        # restored: the issue tracker redacted the original literal to
        # "[email protected]", which EmailStr would reject
        "email": "john.doe@example.com",
        "urls": ["https://example.com", "https://example.org"],
        "ips": ["192.168.1.1", "2001:0db8:85a3:0000:0000:8a2e:0370:7334"],
    },
    "other_key": "other_value",
}

try:
    validated_data = validate_data(good_data)
    print("Validated data:", validated_data)
except ValidationError as e:
    print("Validation error:", e)

# Data that violates the schema in several ways (and omits "other_key").
bad_data = {
    "user": {
        "name": "John",
        "age": "ggggg",  # strict int required; a string is rejected
        # no leading street number, non-string city, six-digit zipcode
        "address": {"street": "Main St.", "city": 123, "zipcode": "123456"},
        "email": "john.doe@invalid",
        "urls": ["invalid-url"],
        "ips": ["999.999"],  # malformed, and only one address where two are required
    },
}

try:
    validated_data = validate_data(bad_data)
    print("Validated data:", validated_data)
except ValidationError as e:
    print("Validation error:", e)

ppigazzini avatar Jan 07 '25 13:01 ppigazzini

So pydantic seems to be somewhat similar to vtjson... :)

vdbergh avatar Jan 07 '25 15:01 vdbergh