API Reference

Auto-generated reference for the grnti_text_classifier package via mkdocstrings. Each module below lists its public classes and functions with their docstrings and type signatures.

Package root

grnti_text_classifier

Production-grade Russian multi-class text classifier (GRNTI).

Data

datamodule

Lightning DataModule wrapping GRNTIDataset for HuggingFace tokenizers.

Classes

GRNTIDataset
GRNTIDataset(df: DataFrame, tokenizer, max_length: int = 256)

Bases: Dataset

Map-style dataset over a processed GRNTI DataFrame.

Source code in src/grnti_text_classifier/data/datamodule.py
def __init__(self, df: pd.DataFrame, tokenizer, max_length: int = 256) -> None:
    self.texts: list[str] = df[TEXT_COL].tolist()
    self.labels = df[ENCODED_COL].to_numpy()
    self.tokenizer = tokenizer
    self.max_length = max_length
GRNTIDataModule
GRNTIDataModule(
    processed_dir: str | Path,
    model_name: str,
    batch_size: int = 16,
    max_length: int = 256,
    num_workers: int = 0,
    seed: int = 42,
)

Bases: LightningDataModule

LightningDataModule that loads train/val/test parquet splits.

Source code in src/grnti_text_classifier/data/datamodule.py
def __init__(
    self,
    processed_dir: str | Path,
    model_name: str,
    batch_size: int = 16,
    max_length: int = 256,
    num_workers: int = 0,
    seed: int = 42,
) -> None:
    super().__init__()
    self.processed_dir = Path(processed_dir)
    self.model_name = model_name
    self.batch_size = batch_size
    self.max_length = max_length
    self.num_workers = num_workers
    self.seed = seed
    self._tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    self.train_df: pd.DataFrame | None = None
    self.val_df: pd.DataFrame | None = None
    self.test_df: pd.DataFrame | None = None

dataset

Dataset implementations.

Classes

TextDataset
TextDataset(
    csv_path: Path | str,
    text_col: str = "text",
    label_col: str = "label",
    tokenizer: Callable[..., Any] | None = None,
    max_length: int = 512,
)

Bases: Dataset[dict[str, Any]]

CSV-backed text classification dataset.

Source code in src/grnti_text_classifier/data/dataset.py
def __init__(
    self,
    csv_path: Path | str,
    text_col: str = "text",
    label_col: str = "label",
    tokenizer: Callable[..., Any] | None = None,
    max_length: int = 512,
) -> None:
    import pandas as pd

    self.df = pd.read_csv(csv_path)
    self.text_col = text_col
    self.label_col = label_col
    self.tokenizer = tokenizer
    self.max_length = max_length

grnti

GRNTI dataset helpers: loader, label encoder, stratified split.

Classes

LabelEncoder dataclass
LabelEncoder(
    code_to_idx: dict[int, int],
    idx_to_code: dict[int, int],
    idx_to_text: dict[int, str],
    num_classes: int,
)

Bidirectional map between raw GRNTI codes and dense 0..N-1 indices.

Functions
encode
encode(labels: Series | list[int]) -> np.ndarray

Map raw codes → dense indices.

Source code in src/grnti_text_classifier/data/grnti.py
def encode(self, labels: pd.Series | list[int]) -> np.ndarray:
    """Map raw codes → dense indices."""
    return np.array([self.code_to_idx[int(c)] for c in labels], dtype=np.int64)
decode
decode(idx: int) -> int

Map dense index → raw code.

Source code in src/grnti_text_classifier/data/grnti.py
def decode(self, idx: int) -> int:
    """Map dense index → raw code."""
    return self.idx_to_code[int(idx)]
decode_text
decode_text(idx: int) -> str

Map dense index → human-readable Russian class name.

Source code in src/grnti_text_classifier/data/grnti.py
def decode_text(self, idx: int) -> str:
    """Map dense index → human-readable Russian class name."""
    return self.idx_to_text[int(idx)]
to_json_dict
to_json_dict() -> dict[str, Any]

Return a plain JSON-serialisable dict.

Source code in src/grnti_text_classifier/data/grnti.py
def to_json_dict(self) -> dict[str, Any]:
    """Return a plain JSON-serialisable dict."""
    return {
        "code_to_idx": {str(k): v for k, v in self.code_to_idx.items()},
        "idx_to_code": {str(k): v for k, v in self.idx_to_code.items()},
        "idx_to_text": {str(k): v for k, v in self.idx_to_text.items()},
        "num_classes": self.num_classes,
    }
from_json_dict classmethod
from_json_dict(d: dict[str, Any]) -> LabelEncoder

Reconstruct a LabelEncoder from a JSON dict.

Source code in src/grnti_text_classifier/data/grnti.py
@classmethod
def from_json_dict(cls, d: dict[str, Any]) -> LabelEncoder:
    """Reconstruct a LabelEncoder from a JSON dict."""
    return cls(
        code_to_idx={int(k): int(v) for k, v in d["code_to_idx"].items()},
        idx_to_code={int(k): int(v) for k, v in d["idx_to_code"].items()},
        idx_to_text={int(k): str(v) for k, v in d["idx_to_text"].items()},
        num_classes=int(d["num_classes"]),
    )
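The serialisation pair above round-trips through JSON by stringifying the integer dict keys (JSON object keys must be strings) and restoring them on load. A minimal sketch of that logic with plain dicts, using toy two-class values rather than the real GRNTI codes:

```python
import json

# A toy two-class encoder state, mirroring LabelEncoder's four fields.
state = {
    "code_to_idx": {27: 0, 50: 1},   # raw GRNTI code -> dense index
    "idx_to_code": {0: 27, 1: 50},
    "idx_to_text": {0: "Математика", 1: "Информатика"},
    "num_classes": 2,
}

# to_json_dict: int keys are stringified so the dict is JSON-serialisable.
serialised = {
    "code_to_idx": {str(k): v for k, v in state["code_to_idx"].items()},
    "idx_to_code": {str(k): v for k, v in state["idx_to_code"].items()},
    "idx_to_text": {str(k): v for k, v in state["idx_to_text"].items()},
    "num_classes": state["num_classes"],
}
payload = json.dumps(serialised, ensure_ascii=False)

# from_json_dict: restore int keys after the JSON round-trip.
d = json.loads(payload)
restored = {
    "code_to_idx": {int(k): int(v) for k, v in d["code_to_idx"].items()},
    "idx_to_code": {int(k): int(v) for k, v in d["idx_to_code"].items()},
    "idx_to_text": {int(k): str(v) for k, v in d["idx_to_text"].items()},
    "num_classes": int(d["num_classes"]),
}
assert restored == state
```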

Functions

load_jsonl
load_jsonl(path: str | Path) -> pd.DataFrame

Read a JSONL file and return a DataFrame with FEATURES columns only.

Source code in src/grnti_text_classifier/data/grnti.py
def load_jsonl(path: str | Path) -> pd.DataFrame:
    """Read a JSONL file and return a DataFrame with FEATURES columns only."""
    df = pd.read_json(path, lines=True)
    # Keep only known FEATURES columns; drop unexpected extras defensively.
    cols = [c for c in FEATURES if c in df.columns]
    return df[cols]
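The defensive column filter in load_jsonl can be illustrated without pandas: parse one JSON object per line, then keep only known feature keys. The FEATURES list below is a hypothetical stand-in for the module's real constant:

```python
import io
import json

# Two JSONL records; "extra" is an unexpected column that gets dropped.
raw = io.StringIO(
    '{"text": "абстракт 1", "label": 27, "extra": 1}\n'
    '{"text": "абстракт 2", "label": 50, "extra": 2}\n'
)
FEATURES = ["text", "label"]  # hypothetical stand-in for the real FEATURES constant

rows = [json.loads(line) for line in raw if line.strip()]
# Keep only known FEATURES keys, mirroring load_jsonl's defensive filter.
records = [{k: r[k] for k in FEATURES if k in r} for r in rows]
assert records[0] == {"text": "абстракт 1", "label": 27}
```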
build_label_encoder
build_label_encoder(df: DataFrame) -> LabelEncoder

Build a LabelEncoder from unique codes in df[LABEL_COL].

Codes are sorted ascending; idx = position in sorted order.

Source code in src/grnti_text_classifier/data/grnti.py
def build_label_encoder(df: pd.DataFrame) -> LabelEncoder:
    """Build a LabelEncoder from unique codes in *df[LABEL_COL]*.

    Codes are sorted ascending; idx = position in sorted order.
    """
    codes: list[int] = sorted(int(c) for c in df[LABEL_COL].unique())
    code_to_idx: dict[int, int] = {c: i for i, c in enumerate(codes)}
    idx_to_code: dict[int, int] = {i: c for i, c in enumerate(codes)}
    idx_to_text: dict[int, str] = {i: _code_to_text(c) for i, c in enumerate(codes)}
    return LabelEncoder(
        code_to_idx=code_to_idx,
        idx_to_code=idx_to_code,
        idx_to_text=idx_to_text,
        num_classes=len(codes),
    )
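The mapping construction reduces to sorting the unique codes and enumerating them; encoding a label is then a dict lookup. A self-contained sketch with toy codes:

```python
# Raw codes as they might appear in df[LABEL_COL], unsorted and repeated.
raw_labels = [50, 27, 68, 27, 50]

codes = sorted(set(raw_labels))                 # [27, 50, 68]
code_to_idx = {c: i for i, c in enumerate(codes)}
idx_to_code = {i: c for i, c in enumerate(codes)}

assert code_to_idx == {27: 0, 50: 1, 68: 2}
# encode() is then a per-label lookup into the sorted order:
encoded = [code_to_idx[c] for c in raw_labels]
assert encoded == [1, 0, 2, 0, 1]
```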
split_stratified_train_val
split_stratified_train_val(
    df: DataFrame, *, val_fraction: float = 0.15, seed: int = 42
) -> tuple[pd.DataFrame, pd.DataFrame]

Split df into train and val subsets with stratification on LABEL_COL.

Source code in src/grnti_text_classifier/data/grnti.py
def split_stratified_train_val(
    df: pd.DataFrame,
    *,
    val_fraction: float = 0.15,
    seed: int = 42,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Split *df* into train and val subsets with stratification on LABEL_COL."""
    train_df, val_df = train_test_split(
        df,
        test_size=val_fraction,
        stratify=df[LABEL_COL],
        random_state=seed,
    )
    return train_df.reset_index(drop=True), val_df.reset_index(drop=True)
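What stratification buys can be shown without sklearn: group row indices by label and draw the same fraction from every group, so class proportions survive the split. A pure-Python sketch of that guarantee (not the sklearn implementation itself):

```python
import random

# Toy labelled rows: 8 of class 0, 4 of class 1.
labels = [0] * 8 + [1] * 4
val_fraction = 0.25
rng = random.Random(42)

# Group row indices by label, then sample the same fraction from each group.
by_label: dict[int, list[int]] = {}
for i, y in enumerate(labels):
    by_label.setdefault(y, []).append(i)

val_idx: set[int] = set()
for y, idxs in by_label.items():
    k = round(len(idxs) * val_fraction)
    val_idx.update(rng.sample(idxs, k))

train_idx = [i for i in range(len(labels)) if i not in val_idx]
# Per-class proportions are preserved: 2 of 8 and 1 of 4 go to val.
assert sum(labels[i] == 0 for i in val_idx) == 2
assert sum(labels[i] == 1 for i in val_idx) == 1
```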

prepare

Data preparation CLI: raw JSONL → processed Parquet + label_encoder.json.

Functions

prepare_data
prepare_data(
    raw_dir: str | Path,
    out_dir: str | Path,
    *,
    val_fraction: float = 0.15,
    seed: int = 42,
) -> None

Load raw JSONL splits, build encoder, write Parquet + JSON artefacts.

Source code in src/grnti_text_classifier/data/prepare.py
def prepare_data(
    raw_dir: str | Path,
    out_dir: str | Path,
    *,
    val_fraction: float = 0.15,
    seed: int = 42,
) -> None:
    """Load raw JSONL splits, build encoder, write Parquet + JSON artefacts."""
    raw = Path(raw_dir)
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)

    train_raw = load_jsonl(raw / "train.jsonl")
    test_df = load_jsonl(raw / "test.jsonl")

    # Build encoder from all codes present in both splits combined.
    combined = pd.concat([train_raw, test_df], ignore_index=True)
    encoder = build_label_encoder(combined)

    # Stratified train / val split.
    train_df, val_df = split_stratified_train_val(train_raw, val_fraction=val_fraction, seed=seed)

    # Add dense label index to each split.
    train_df = train_df.copy()
    val_df = val_df.copy()
    test_df = test_df.copy()
    train_df[ENCODED_COL] = encoder.encode(train_df["label"])
    val_df[ENCODED_COL] = encoder.encode(val_df["label"])
    test_df[ENCODED_COL] = encoder.encode(test_df["label"])

    # Write Parquet files.
    train_df.to_parquet(out / "train.parquet", index=False)
    val_df.to_parquet(out / "val.parquet", index=False)
    test_df.to_parquet(out / "test.parquet", index=False)

    # Write label encoder JSON (ensure_ascii=False for readable Cyrillic).
    encoder_path = out / "label_encoder.json"
    encoder_path.write_text(
        json.dumps(encoder.to_json_dict(), ensure_ascii=False, indent=2),
        encoding="utf-8",
    )

    print(
        f"[prepare] train={len(train_df)} val={len(val_df)} test={len(test_df)}"
        f" classes={encoder.num_classes}"
    )
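The ensure_ascii=False choice when writing label_encoder.json matters for human inspection: without it, json.dumps escapes every Cyrillic character. Both forms decode to the same data:

```python
import json

entry = {"0": "Математика"}
readable = json.dumps(entry, ensure_ascii=False)
escaped = json.dumps(entry)  # default ensure_ascii=True

assert "Математика" in readable   # Cyrillic kept as-is on disk
assert "\\u" in escaped           # default output escapes every character
assert json.loads(readable) == json.loads(escaped) == entry
```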

transforms

Image transforms for training and inference.

Models

factory

Model factories — return a pretrained HuggingFace model ready for fine-tuning.

Functions

build_main
build_main(num_labels: int = 28) -> PreTrainedModel

Return XLM-RoBERTa-base configured for sequence classification.

Parameters:

num_labels (int, default 28): Number of output classes (28 for GRNTI).

Returns:

PreTrainedModel: AutoModelForSequenceClassification instance.

Source code in src/grnti_text_classifier/models/factory.py
def build_main(num_labels: int = 28) -> PreTrainedModel:
    """Return XLM-RoBERTa-base configured for sequence classification.

    Args:
        num_labels: Number of output classes (default 28 for GRNTI).

    Returns:
        AutoModelForSequenceClassification instance.
    """
    return AutoModelForSequenceClassification.from_pretrained(
        "FacebookAI/xlm-roberta-base",
        num_labels=num_labels,
        ignore_mismatched_sizes=True,
    )
build_baseline
build_baseline(num_labels: int = 28) -> PreTrainedModel

Return ruBERT-base-cased configured for sequence classification.

Parameters:

num_labels (int, default 28): Number of output classes (28 for GRNTI).

Returns:

PreTrainedModel: AutoModelForSequenceClassification instance.

Source code in src/grnti_text_classifier/models/factory.py
def build_baseline(num_labels: int = 28) -> PreTrainedModel:
    """Return ruBERT-base-cased configured for sequence classification.

    Args:
        num_labels: Number of output classes (default 28 for GRNTI).

    Returns:
        AutoModelForSequenceClassification instance.
    """
    return AutoModelForSequenceClassification.from_pretrained(
        "DeepPavlov/rubert-base-cased",
        num_labels=num_labels,
        ignore_mismatched_sizes=True,
    )

lightning_module

Lightning wrapper for GRNTI sequence classification models.

Classes

GRNTIClassifier
GRNTIClassifier(
    model: PreTrainedModel,
    class_weights: Tensor | None = None,
    lr: float = 2e-05,
    weight_decay: float = 0.01,
    warmup_ratio: float = 0.1,
    total_steps: int = 1000,
    num_classes: int = 28,
)

Bases: LightningModule

Lightning module wrapping any HuggingFace sequence-classification model.

Source code in src/grnti_text_classifier/models/lightning_module.py
def __init__(
    self,
    model: PreTrainedModel,
    class_weights: torch.Tensor | None = None,
    lr: float = 2e-5,
    weight_decay: float = 0.01,
    warmup_ratio: float = 0.1,
    total_steps: int = 1000,
    num_classes: int = 28,
) -> None:
    super().__init__()
    self.model = model
    self.class_weights = class_weights
    self.save_hyperparameters(ignore=["model", "class_weights"])

    self.train_f1 = MulticlassF1Score(num_classes=num_classes, average="macro")
    self.val_f1 = MulticlassF1Score(num_classes=num_classes, average="macro")
    self.test_f1 = MulticlassF1Score(num_classes=num_classes, average="macro")

    _top5 = min(5, num_classes)
    self.val_top1 = MulticlassAccuracy(num_classes=num_classes, top_k=1, average="micro")
    self.val_top5 = MulticlassAccuracy(num_classes=num_classes, top_k=_top5, average="micro")
    self.test_top1 = MulticlassAccuracy(num_classes=num_classes, top_k=1, average="micro")
    self.test_top5 = MulticlassAccuracy(num_classes=num_classes, top_k=_top5, average="micro")
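The warmup_ratio and total_steps hyper-parameters imply a warmup-then-decay learning-rate schedule. The scheduler itself is not shown here, so the sketch below assumes the common linear-warmup/linear-decay shape (the kind passed to a LambdaLR-style multiplier); treat the exact shape as an assumption:

```python
def lr_lambda(step: int, total_steps: int = 1000, warmup_ratio: float = 0.1) -> float:
    """Multiplier on the peak lr: linear warmup, then linear decay to zero.

    Assumed schedule shape; the module's actual scheduler may differ.
    """
    warmup_steps = int(total_steps * warmup_ratio)
    if step < warmup_steps:
        return step / max(1, warmup_steps)          # ramp 0 -> 1 over warmup
    # decay 1 -> 0 over the remaining steps
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

assert lr_lambda(0) == 0.0
assert lr_lambda(100) == 1.0    # end of warmup: peak lr
assert lr_lambda(1000) == 0.0   # fully decayed
```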

Training

optuna_sweep

Optuna hyper-parameter sweep over train_one for GRNTI classifiers.

Functions

run_sweep
run_sweep(
    processed_dir: Path,
    out_dir: Path,
    *,
    model_builder: Callable[..., Any],
    model_name_for_tokenizer: str,
    n_trials: int = 10,
    seed: int = 42,
    trial_epochs: int = 3,
    batch_size: int = 16,
    num_workers: int = 0,
) -> dict[str, Any]

Run an Optuna TPE sweep over learning-rate, weight-decay and warmup-ratio.

Parameters

processed_dir: Pre-processed data directory (parquet splits + label_encoder.json).
out_dir: Root directory; each trial writes to out_dir / "trial_<n>".
model_builder: Callable (num_labels: int) -> PreTrainedModel.
model_name_for_tokenizer: HuggingFace model name used to build the tokeniser.
n_trials: Number of Optuna trials.
seed: Seed for the TPE sampler (for reproducibility).
trial_epochs: Max epochs per trial (use a small number for speed).
batch_size: Mini-batch size forwarded to train_one.
num_workers: DataLoader workers forwarded to train_one.

Returns

dict {"best_params": {...}, "best_value": float}

Source code in src/grnti_text_classifier/training/optuna_sweep.py
def run_sweep(
    processed_dir: Path,
    out_dir: Path,
    *,
    model_builder: Callable[..., Any],
    model_name_for_tokenizer: str,
    n_trials: int = 10,
    seed: int = 42,
    trial_epochs: int = 3,
    batch_size: int = 16,
    num_workers: int = 0,
) -> dict[str, Any]:
    """Run an Optuna TPE sweep over learning-rate, weight-decay and warmup-ratio.

    Parameters
    ----------
    processed_dir:
        Pre-processed data directory (parquet splits + label_encoder.json).
    out_dir:
        Root directory; each trial writes to ``out_dir / "trial_<n>"``.
    model_builder:
        Callable ``(num_labels: int) -> PreTrainedModel``.
    model_name_for_tokenizer:
        HuggingFace model name used to build the tokeniser.
    n_trials:
        Number of Optuna trials.
    seed:
        Seed for the TPE sampler (for reproducibility).
    trial_epochs:
        Max epochs per trial (use a small number for speed).
    batch_size:
        Mini-batch size forwarded to ``train_one``.
    num_workers:
        DataLoader workers forwarded to ``train_one``.

    Returns
    -------
    dict
        ``{"best_params": {...}, "best_value": float}``
    """
    processed_dir = Path(processed_dir)
    out_dir = Path(out_dir)

    def objective(trial: optuna.Trial) -> float:
        lr = trial.suggest_float("lr", 1e-5, 5e-5, log=True)
        weight_decay = trial.suggest_float("weight_decay", 0.01, 0.1, log=False)
        warmup_ratio = trial.suggest_float("warmup_ratio", 0.05, 0.15, log=False)

        trial_dir = out_dir / f"trial_{trial.number}"

        train_one(
            model_builder,
            model_name_for_tokenizer,
            processed_dir,
            trial_dir,
            max_epochs=trial_epochs,
            batch_size=batch_size,
            lr=lr,
            weight_decay=weight_decay,
            warmup_ratio=warmup_ratio,
            patience=1,
            num_workers=num_workers,
            save_hf=False,
        )

        # Read the best val/macro_f1 from the CSVLogger output
        metrics_csv = trial_dir / "logs" / "version_0" / "metrics.csv"
        import pandas as pd

        df = pd.read_csv(metrics_csv)
        if "val/macro_f1" not in df.columns:
            best_f1 = 0.0
        else:
            best_f1 = float(df["val/macro_f1"].dropna().max())

        import gc

        import torch

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
        return best_f1

    study = optuna.create_study(
        direction="maximize",
        sampler=TPESampler(seed=seed),
    )
    study.optimize(objective, n_trials=n_trials)

    return {"best_params": study.best_params, "best_value": float(study.best_value)}
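The log=True flag on the lr search space samples uniformly in log-space rather than linearly, which spreads trials evenly across orders of magnitude. A stdlib sketch of that sampling behaviour (an illustration of the distribution, not Optuna's internals):

```python
import math
import random

rng = random.Random(0)

def log_uniform(low: float, high: float) -> float:
    """Sample uniformly in log-space, like suggest_float(..., log=True)."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))

samples = [log_uniform(1e-5, 5e-5) for _ in range(1000)]
assert all(1e-5 <= s <= 5e-5 for s in samples)

# In log-space, values below the geometric mean are as likely as above it.
geo_mid = math.sqrt(1e-5 * 5e-5)
below = sum(s < geo_mid for s in samples)
assert 400 < below < 600
```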

train

Lightning Trainer entrypoint for GRNTI text classification.

Functions

train_one
train_one(
    model_builder: Callable[..., Any],
    model_name_for_tokenizer: str,
    processed_dir: Path,
    out_dir: Path,
    *,
    max_epochs: int = 5,
    batch_size: int = 16,
    lr: float = 2e-05,
    weight_decay: float = 0.01,
    warmup_ratio: float = 0.1,
    patience: int = 2,
    seed: int = 42,
    max_length: int = 256,
    num_workers: int = 0,
    save_hf: bool = True,
) -> Path

Train a single GRNTI classifier run.

Parameters

model_builder: Callable (num_labels: int) -> PreTrainedModel.
model_name_for_tokenizer: HuggingFace model name used to load the tokenizer, e.g. "FacebookAI/xlm-roberta-base".
processed_dir: Directory containing train.parquet, val.parquet, test.parquet, and label_encoder.json.
out_dir: Root output directory for this run.
max_epochs: Maximum training epochs.
batch_size: Batch size for training and validation.
lr: Peak learning rate for AdamW.
weight_decay: AdamW weight-decay coefficient.
warmup_ratio: Fraction of total steps used for linear warmup.
patience: Early-stopping patience (in validation epochs).
seed: Global random seed.
max_length: Tokeniser max-sequence length.
num_workers: DataLoader worker count.
save_hf: When True (default) saves the best checkpoint as a HuggingFace model directory and returns that path. When False, skips the HF save and returns the raw checkpoint path instead (useful for sweeps).

Returns

Path out_dir / "hf" when save_hf is True; otherwise the path to the best Lightning checkpoint file.

Source code in src/grnti_text_classifier/training/train.py
def train_one(
    model_builder: Callable[..., Any],
    model_name_for_tokenizer: str,
    processed_dir: Path,
    out_dir: Path,
    *,
    max_epochs: int = 5,
    batch_size: int = 16,
    lr: float = 2e-5,
    weight_decay: float = 0.01,
    warmup_ratio: float = 0.1,
    patience: int = 2,
    seed: int = 42,
    max_length: int = 256,
    num_workers: int = 0,
    save_hf: bool = True,
) -> Path:
    """Train a single GRNTI classifier run.

    Parameters
    ----------
    model_builder:
        Callable ``(num_labels: int) -> PreTrainedModel``.
    model_name_for_tokenizer:
        HuggingFace model name used to load the tokenizer, e.g.
        ``"FacebookAI/xlm-roberta-base"``.
    processed_dir:
        Directory containing ``train.parquet``, ``val.parquet``,
        ``test.parquet``, and ``label_encoder.json``.
    out_dir:
        Root output directory for this run.
    max_epochs:
        Maximum training epochs.
    batch_size:
        Batch size for training and validation.
    lr:
        Peak learning rate for AdamW.
    weight_decay:
        AdamW weight-decay coefficient.
    warmup_ratio:
        Fraction of total steps used for linear warmup.
    patience:
        Early-stopping patience (in validation epochs).
    seed:
        Global random seed.
    max_length:
        Tokeniser max-sequence length.
    num_workers:
        DataLoader worker count.
    save_hf:
        When ``True`` (default) saves the best checkpoint as a HuggingFace
        model directory and returns that path.  When ``False``, skips the HF
        save and returns the raw checkpoint path instead (useful for sweeps).

    Returns
    -------
    Path
        ``out_dir / "hf"`` when *save_hf* is True; otherwise the path to the
        best Lightning checkpoint file.
    """
    processed_dir = Path(processed_dir)
    out_dir = Path(out_dir)

    # 1. Global seed
    L.seed_everything(seed, workers=True)

    # 2. DataModule
    dm = GRNTIDataModule(
        processed_dir,
        model_name_for_tokenizer,
        batch_size=batch_size,
        max_length=max_length,
        num_workers=num_workers,
        seed=seed,
    )
    dm.setup()

    # 3. Inverse-frequency class weights
    label_enc = json.loads((processed_dir / "label_encoder.json").read_text(encoding="utf-8"))
    num_classes: int = int(label_enc["num_classes"])

    import pandas as pd

    train_df = pd.read_parquet(processed_dir / "train.parquet")
    freq = np.bincount(train_df["label_idx"].to_numpy(), minlength=num_classes).astype(np.float64)
    weights = 1.0 / np.clip(freq, 1, None)
    weights = weights / weights.mean()
    class_weights = torch.tensor(weights, dtype=torch.float32)

    # 4. Build model + Lightning module
    inner = model_builder(num_labels=num_classes)
    steps_per_epoch = len(dm.train_dataloader())
    total_steps = steps_per_epoch * max_epochs

    lit = GRNTIClassifier(
        inner,
        class_weights=class_weights,
        lr=lr,
        weight_decay=weight_decay,
        warmup_ratio=warmup_ratio,
        total_steps=total_steps,
        num_classes=num_classes,
    )

    # 5. Hardware
    precision: str | int = "bf16-mixed" if torch.cuda.is_available() else 32
    accelerator = "gpu" if torch.cuda.is_available() else "cpu"

    # 6. Callbacks + logger
    ckpt_cb = ModelCheckpoint(
        dirpath=out_dir / "ckpt",
        filename="{epoch:02d}-{val_macro_f1:.4f}",
        monitor="val/macro_f1",
        mode="max",
        save_top_k=1,
        save_last=False,
    )
    es_cb = EarlyStopping(monitor="val/macro_f1", mode="max", patience=patience)
    logger = CSVLogger(save_dir=str(out_dir), name="logs")

    # 7. Trainer
    trainer = L.Trainer(
        accelerator=accelerator,
        devices=1,
        precision=precision,  # type: ignore[arg-type]  # template stub; revisit in backport
        max_epochs=max_epochs,
        callbacks=[ckpt_cb, es_cb],
        logger=logger,
        deterministic="warn",
        enable_progress_bar=True,
        log_every_n_steps=max(1, steps_per_epoch // 10),
    )
    trainer.fit(lit, datamodule=dm)

    # 8. Optionally reload best checkpoint and export to HF format
    if not save_hf:
        return Path(ckpt_cb.best_model_path)

    best = GRNTIClassifier.load_from_checkpoint(
        ckpt_cb.best_model_path,
        model=model_builder(num_labels=num_classes),
        class_weights=None,
    )

    hf_dir = out_dir / "hf"
    hf_dir.mkdir(parents=True, exist_ok=True)
    best.model.save_pretrained(hf_dir)
    dm._tok.save_pretrained(hf_dir)
    return hf_dir
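Step 3 of train_one (inverse-frequency class weights) is easy to reproduce without numpy: invert each class frequency, clip at 1 so empty classes do not divide by zero, then normalise so the mean weight is 1. A minimal sketch of that arithmetic:

```python
# Toy class frequencies over a 4-class training set (class 3 is absent).
freq = [90.0, 9.0, 1.0, 0.0]

# Inverse frequency, clipping at 1 so absent classes don't divide by zero.
weights = [1.0 / max(f, 1.0) for f in freq]
mean_w = sum(weights) / len(weights)
weights = [w / mean_w for w in weights]   # normalise so the mean weight is 1

assert abs(sum(weights) / len(weights) - 1.0) < 1e-9
assert weights[2] > weights[1] > weights[0]   # rarer classes get larger weights
```

Normalising to mean 1 keeps the weighted loss on the same overall scale as the unweighted one, so the learning rate does not need retuning.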

Evaluation

confusion

Confusion matrix visualisation — saves a seaborn heatmap PNG.

Functions

save_confusion_matrix
save_confusion_matrix(
    y_true: ndarray, preds: ndarray, labels: list[str], out_path: str | Path
) -> None

Save a row-normalised confusion matrix heatmap to out_path (PNG).

Parameters

y_true: Ground-truth integer labels, shape (n,).
preds: Predicted integer labels, shape (n,).
labels: Human-readable class names (e.g. ["Математика", "Информатика"]). Length must equal the number of classes.
out_path: Destination file path. Parent directories are created if absent.

Source code in src/grnti_text_classifier/evaluation/confusion.py
def save_confusion_matrix(
    y_true: np.ndarray,
    preds: np.ndarray,
    labels: list[str],
    out_path: str | Path,
) -> None:
    """Save a row-normalised confusion matrix heatmap to *out_path* (PNG).

    Parameters
    ----------
    y_true:
        Ground-truth integer labels, shape ``(n,)``.
    preds:
        Predicted integer labels, shape ``(n,)``.
    labels:
        Human-readable class names (e.g. ``["Математика", "Информатика"]``).
        Length must equal the number of classes.
    out_path:
        Destination file path.  Parent directories are created if absent.
    """
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    cm = confusion_matrix(
        y_true,
        preds,
        labels=list(range(len(labels))),
        normalize="true",
    )

    _fig, ax = plt.subplots(figsize=(14, 12))
    sns.heatmap(
        cm,
        annot=False,
        cmap="Blues",
        xticklabels=labels,
        yticklabels=labels,
        ax=ax,
    )
    ax.set_title("Confusion matrix (main model, row-normalised)")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    plt.tight_layout()
    plt.savefig(out_path, dpi=150)
    plt.close("all")
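The normalize="true" option divides each row of the raw confusion counts by its sum, so row i reads as "of all samples truly in class i, what fraction went to each predicted class" (i.e. the diagonal holds per-class recall). A plain-Python sketch of the normalisation:

```python
# Raw confusion counts for 2 classes: rows = true class, cols = predicted.
counts = [[8, 2],
          [1, 3]]

# normalize="true" divides each row by its sum, so each row sums to 1.
norm = [[c / sum(row) for c in row] for row in counts]

assert norm[0] == [0.8, 0.2]
assert norm[1] == [0.25, 0.75]
assert all(abs(sum(row) - 1.0) < 1e-9 for row in norm)
```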

evaluate

CLI: score a saved HF checkpoint on a processed parquet split.

Functions

main
main() -> None

Entry point for the scoring CLI.

Source code in src/grnti_text_classifier/evaluation/evaluate.py
def main() -> None:
    """Entry point for the scoring CLI."""
    parser = argparse.ArgumentParser(
        description="Score a saved HF checkpoint on a processed parquet split."
    )
    parser.add_argument("--hf-dir", required=True, help="HF model directory from train_one")
    parser.add_argument(
        "--split", required=True, help="Parquet file path (e.g. data/processed/test.parquet)"
    )
    parser.add_argument(
        "--label-encoder",
        required=True,
        help="Label encoder JSON (e.g. data/processed/label_encoder.json)",
    )
    parser.add_argument("--out", required=True, help="Output metrics JSON path")
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--max-length", type=int, default=256)
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(args.hf_dir)
    model = AutoModelForSequenceClassification.from_pretrained(args.hf_dir)
    model.to(device)
    model.train(False)

    df = pd.read_parquet(args.split)
    with open(args.label_encoder, encoding="utf-8") as fh:
        encoder = json.load(fh)
    num_classes = int(encoder["num_classes"])  # len(encoder) would be 4 (the JSON's top-level keys)

    texts = df["text"].tolist()
    all_logits: list[np.ndarray] = []

    for start in range(0, len(texts), args.batch_size):
        batch_texts = texts[start : start + args.batch_size]
        inputs = tokenizer(
            batch_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=args.max_length,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.inference_mode():
            out = model(**inputs)
        all_logits.append(out.logits.cpu().numpy())

    logits = np.concatenate(all_logits, axis=0)
    y_true = df["label_idx"].to_numpy()

    metrics = compute_metrics(y_true, logits, num_classes=num_classes)

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(metrics, indent=2), encoding="utf-8")

    print(metrics)
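The batching loop above walks the text list in fixed strides; the last batch is simply whatever remains. The slicing pattern in isolation:

```python
texts = [f"абстракт {i}" for i in range(7)]
batch_size = 3

# Same stride pattern as the scoring loop: start = 0, 3, 6, ...
batches = [texts[start : start + batch_size]
           for start in range(0, len(texts), batch_size)]

assert [len(b) for b in batches] == [3, 3, 1]   # last batch is the remainder
assert sum(batches, []) == texts                # order and coverage preserved
```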

metrics

Metrics computation for classification scoring.

Functions

compute_metrics
compute_metrics(
    y_true: ndarray, logits: ndarray | object, num_classes: int
) -> dict[str, Any]

Return top-1/top-5 accuracy, macro/weighted F1, num_classes, and n.

Parameters

y_true: Integer class indices, shape (n,).
logits: Raw model outputs, shape (n, num_classes). Accepts either a NumPy array or a torch.Tensor — tensors are converted to NumPy automatically.
num_classes: Total number of label classes.

Returns

dict with keys: top1_accuracy, top5_accuracy, macro_f1, weighted_f1, num_classes, n.

Source code in src/grnti_text_classifier/evaluation/metrics.py
def compute_metrics(
    y_true: np.ndarray,
    logits: np.ndarray | object,
    num_classes: int,
) -> dict[str, Any]:
    """Return top-1/top-5 accuracy, macro/weighted F1, num_classes, and n.

    Parameters
    ----------
    y_true:
        Integer class indices, shape ``(n,)``.
    logits:
        Raw model outputs, shape ``(n, num_classes)``.  Accepts either a
        NumPy array or a torch.Tensor — tensors are converted to NumPy
        automatically.
    num_classes:
        Total number of label classes.

    Returns
    -------
    dict with keys: top1_accuracy, top5_accuracy, macro_f1, weighted_f1,
    num_classes, n.
    """
    # Accept torch tensors without importing torch at module level.
    if hasattr(logits, "cpu"):
        logits = logits.cpu().numpy()
    logits = np.asarray(logits, dtype=np.float32)

    labels = list(range(num_classes))
    preds = logits.argmax(axis=-1)

    top1 = float(top_k_accuracy_score(y_true, logits, k=1, labels=labels))
    top5 = float(top_k_accuracy_score(y_true, logits, k=min(5, num_classes), labels=labels))
    macro_f1 = float(f1_score(y_true, preds, average="macro", zero_division=0))
    weighted_f1 = float(f1_score(y_true, preds, average="weighted", zero_division=0))

    return {
        "top1_accuracy": top1,
        "top5_accuracy": top5,
        "macro_f1": macro_f1,
        "weighted_f1": weighted_f1,
        "num_classes": int(num_classes),
        "n": len(y_true),
    }
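Top-k accuracy counts a sample as correct when the true label appears among the k highest-scoring logits, so top-5 is always at least top-1. A dependency-free sketch of the metric on hand-checked toy logits:

```python
# Logits for 4 samples over 3 classes; true labels alongside.
logits = [
    [2.0, 1.0, 0.0],   # label 0: hit at k=1
    [0.0, 3.0, 1.0],   # label 1: hit at k=1
    [1.0, 0.5, 0.9],   # label 2: miss at k=1 (argmax 0), hit at k=2
    [0.3, 0.1, 0.5],   # label 1: miss at both k=1 and k=2
]
y_true = [0, 1, 2, 1]

def top_k_hit(row: list[float], label: int, k: int) -> bool:
    # Rank class indices by score, descending; check the top k.
    ranked = sorted(range(len(row)), key=lambda i: row[i], reverse=True)
    return label in ranked[:k]

top1 = sum(top_k_hit(r, y, 1) for r, y in zip(logits, y_true)) / len(y_true)
top2 = sum(top_k_hit(r, y, 2) for r, y in zip(logits, y_true)) / len(y_true)

assert top1 == 0.5    # samples 0 and 1 correct at k=1
assert top2 == 0.75   # sample 2 also correct at k=2
```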

report

Summary report builder — merges main and baseline metrics into a JSON file.

Functions

build_summary
build_summary(
    main_metrics: dict[str, Any],
    baseline_metrics: dict[str, Any],
    *,
    out_path: str | Path,
) -> dict[str, Any]

Write a flat JSON summary combining main and baseline scoring results.

Parameters

main_metrics: Output of compute_metrics for the primary model.
baseline_metrics: Output of compute_metrics for the baseline model.
out_path: Destination path for the JSON file. Parent dirs are created if needed.

Returns

The summary dict that was written to disk.

Source code in src/grnti_text_classifier/evaluation/report.py
def build_summary(
    main_metrics: dict[str, Any],
    baseline_metrics: dict[str, Any],
    *,
    out_path: str | Path,
) -> dict[str, Any]:
    """Write a flat JSON summary combining main and baseline scoring results.

    Parameters
    ----------
    main_metrics:
        Output of ``compute_metrics`` for the primary model.
    baseline_metrics:
        Output of ``compute_metrics`` for the baseline model.
    out_path:
        Destination path for the JSON file.  Parent dirs are created if needed.

    Returns
    -------
    The summary dict that was written to disk.
    """
    assert main_metrics["n"] == baseline_metrics["n"], (
        f"Test-set sizes differ: {main_metrics['n']} vs {baseline_metrics['n']}"
    )
    assert main_metrics["num_classes"] == baseline_metrics["num_classes"], (
        "num_classes mismatch between main and baseline"
    )

    def pct(v: float) -> str:
        return f"{v:.1%}"

    summary = {
        "main_model": "FacebookAI/xlm-roberta-base",
        "baseline_model": "DeepPavlov/rubert-base-cased",
        "main_top1": pct(main_metrics["top1_accuracy"]),
        "main_top5": pct(main_metrics["top5_accuracy"]),
        "main_macro_f1": pct(main_metrics["macro_f1"]),
        "main_weighted_f1": pct(main_metrics["weighted_f1"]),
        "baseline_top1": pct(baseline_metrics["top1_accuracy"]),
        "baseline_top5": pct(baseline_metrics["top5_accuracy"]),
        "baseline_macro_f1": pct(baseline_metrics["macro_f1"]),
        "baseline_weighted_f1": pct(baseline_metrics["weighted_f1"]),
        "test_size": main_metrics["n"],
        "num_classes": main_metrics["num_classes"],
    }

    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
    return summary

Inference

predict

Inference CLI — load a checkpoint and predict on input(s).

Functions

load_model
load_model(checkpoint_path: str | Path) -> Any

Load a Lightning module from checkpoint, rebuilding the backbone from hparams.

Source code in src/grnti_text_classifier/inference/predict.py
def load_model(checkpoint_path: str | Path) -> Any:
    """Load a Lightning module from checkpoint, rebuilding the backbone from hparams."""
    import torch

    from ..models import GRNTIClassifier, build_main

    ckpt = torch.load(str(checkpoint_path), map_location="cpu", weights_only=False)
    hp = ckpt.get("hyper_parameters", {})
    num_labels = hp.get("num_classes", 28)  # save_hyperparameters stores num_classes, not num_labels
    backbone = build_main(num_labels=int(num_labels))
    return GRNTIClassifier.load_from_checkpoint(str(checkpoint_path), model=backbone)
predict
predict(model: Any, input_path: str | Path) -> dict[str, Any]

Run a single prediction. Returns a task-specific result dict.

Source code in src/grnti_text_classifier/inference/predict.py
def predict(model: Any, input_path: str | Path) -> dict[str, Any]:
    """Run a single prediction. Returns a task-specific result dict."""
    raise NotImplementedError("Override predict() per project")

Serving

dependencies

Dependency injection — singleton model loader.

errors

Exception types and handlers.

main

FastAPI application.

routes

GRNTI classifier routes — /health, /classify, /labels.

schemas

Pydantic request/response schemas for the /classify endpoint.

Classes

TextPayload

Bases: BaseModel

Request body for text classification — raw abstract + optional token budget.

LabelProb

Bases: BaseModel

GRNTI class identifier together with its human-readable name and probability.

ClassificationResponse

Bases: BaseModel

Response payload of /classify: top-1 plus top-5 probabilities and metadata.

LabelEntry

Bases: BaseModel

Label catalog entry returned by /labels — numeric id plus human-readable name.

Utilities

hf_hub

HuggingFace Hub helpers.

logging

Structured logging configuration.

seed

Deterministic seeding across libraries.