API Reference

Auto-generated reference for the chest_xray_classifier package via mkdocstrings. Each module below lists its public classes and functions with their docstrings and type signatures.

Package root

chest_xray_classifier

Production-grade 3-class chest X-ray classifier.

Data

datamodule

Lightning DataModule.

Classes

dataset

Dataset implementations.

Classes

ImageDataset
ImageDataset(
    root: Path | str,
    transform: Callable | None = None,
    extensions: tuple[str, ...] = (".jpg", ".jpeg", ".png", ".tif", ".tiff"),
)

Bases: Dataset

Generic image dataset with class-subdir layout.

Source code in src/chest_xray_classifier/data/dataset.py
def __init__(
    self,
    root: Path | str,
    transform: Callable | None = None,
    extensions: tuple[str, ...] = (".jpg", ".jpeg", ".png", ".tif", ".tiff"),
) -> None:
    self.root = Path(root)
    self.transform = transform
    self.samples: list[tuple[Path, int]] = []
    classes = sorted(p.name for p in self.root.iterdir() if p.is_dir())
    self.class_to_idx = {c: i for i, c in enumerate(classes)}
    for cls, idx in self.class_to_idx.items():
        for ext in extensions:
            self.samples.extend((p, idx) for p in (self.root / cls).glob(f"**/*{ext}"))
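ImageDataset expects one subdirectory per class under root. Class indices follow the sorted subdirectory names, and every file under a class directory that matches one of the extensions (searched recursively) becomes a (path, index) sample. The standalone sketch below reproduces that indexing logic without importing the package so the mapping is easy to check. The function name index_class_subdirs is hypothetical, and so are the class names used in the test.

```python
from __future__ import annotations

from pathlib import Path


def index_class_subdirs(
    root: Path | str,
    extensions: tuple[str, ...] = (".jpg", ".jpeg", ".png", ".tif", ".tiff"),
) -> tuple[dict[str, int], list[tuple[Path, int]]]:
    """Mirror ImageDataset.__init__: sorted class subdirs -> indices, recursive glob per extension."""
    root = Path(root)
    # Class indices follow the alphabetical order of the subdirectory names.
    classes = sorted(p.name for p in root.iterdir() if p.is_dir())
    class_to_idx = {c: i for i, c in enumerate(classes)}
    samples: list[tuple[Path, int]] = []
    for cls, idx in class_to_idx.items():
        for ext in extensions:
            # "**/*{ext}" matches files directly in the class dir and in any nested subfolder.
            samples.extend((p, idx) for p in (root / cls).glob(f"**/*{ext}"))
    return class_to_idx, samples
```

Because indices come from sorted names, adding a new class directory shifts the index of every class that sorts after it, so the mapping is effectively part of a trained checkpoint's contract. Also note that pathlib globbing is case-sensitive on POSIX systems, so files ending in .JPG or .PNG would be skipped there.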

prepare

Split the Kaggle chest X-ray dataset into a 3-class layout.
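The docstring does not say how the three classes are derived. With the Kaggle "Chest X-Ray Images (Pneumonia)" dataset, which ships NORMAL and PNEUMONIA folders, one common approach is to split PNEUMONIA by the bacteria/virus token in each filename. The sketch below assumes that scheme; the function name three_class_label and the labels normal/bacterial/viral are hypothetical and not necessarily what prepare produces.

```python
from __future__ import annotations

from pathlib import PurePath


def three_class_label(path: str | PurePath) -> str:
    """Hypothetical mapping from a Kaggle file path to one of three class names."""
    p = PurePath(path)
    # In the original two-class layout the parent folder is NORMAL or PNEUMONIA.
    if p.parent.name.upper() == "NORMAL":
        return "normal"
    name = p.name.lower()
    # Pneumonia filenames carry the subtype, in the form person<N>_bacteria_<M>.jpeg or person<N>_virus_<M>.jpeg.
    if "bacteria" in name:
        return "bacterial"
    if "virus" in name:
        return "viral"
    raise ValueError(f"cannot infer pneumonia subtype from {p.name!r}")
```

Copying each file into a directory named after its label would give the class-subdir layout that ImageDataset reads.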

transforms

Image transforms for training and inference.

Models

baseline_dinov2

DINOv2 ViT-S feature extractor + linear classification head.
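For a sense of scale: DINOv2 ViT-S/14 produces 384-dimensional embeddings, so a linear head mapping one embedding to 3 logits is tiny next to the roughly 21M-parameter backbone. The arithmetic below assumes the head is a single nn.Linear applied to one pooled (for example CLS) embedding and that inputs are 224x224. The docstring states neither detail.

```python
def linear_head_params(embed_dim: int, num_classes: int) -> int:
    # nn.Linear(embed_dim, num_classes): a weight matrix plus one bias per class.
    return embed_dim * num_classes + num_classes


def vit_patch_tokens(image_size: int, patch_size: int) -> int:
    # Non-overlapping square patches tiling a square image.
    per_side = image_size // patch_size
    return per_side * per_side


EMBED_DIM = 384  # DINOv2 ViT-S/14 embedding width
PATCH = 14       # ViT-S/14 patch size

print(linear_head_params(EMBED_DIM, 3))  # 1155
print(vit_patch_tokens(224, PATCH))      # 256
```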

factory

Model factory — returns a torch.nn.Module by name.
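load_model calls build_model(model_name, num_classes=..., pretrained=False), which suggests a name-keyed registry. The sketch below shows that pattern in plain Python. The registry, the register decorator, the "toy" name and ToyModel are hypothetical stand-ins for the package's real torch.nn.Module builders, and making num_classes and pretrained keyword-only is an assumption.

```python
from __future__ import annotations

from typing import Any, Callable

# Maps a model name to a builder accepting num_classes and pretrained keywords.
_REGISTRY: dict[str, Callable[..., Any]] = {}


def register(name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    def decorator(builder: Callable[..., Any]) -> Callable[..., Any]:
        if name in _REGISTRY:
            raise KeyError(f"model {name!r} already registered")
        _REGISTRY[name] = builder
        return builder
    return decorator


def build_model(name: str, *, num_classes: int, pretrained: bool) -> Any:
    try:
        builder = _REGISTRY[name]
    except KeyError:
        known = ", ".join(sorted(_REGISTRY))
        raise ValueError(f"unknown model {name!r}; known: {known}") from None
    return builder(num_classes=num_classes, pretrained=pretrained)


class ToyModel:
    """Placeholder for a real torch.nn.Module."""

    def __init__(self, num_classes: int, pretrained: bool) -> None:
        self.num_classes = num_classes
        self.pretrained = pretrained


@register("toy")
def _build_toy(num_classes: int, pretrained: bool) -> ToyModel:
    return ToyModel(num_classes, pretrained)
```

Passing pretrained=False when rebuilding, as load_model does, presumably avoids fetching pretrained weights that load_from_checkpoint would overwrite with the checkpoint's state dict anyway.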

lightning_module

Lightning module wrappers.

Training

train

Training entrypoint (Hydra-powered).

Evaluation

evaluate

Run the model on the test set and write the metrics to reports/metrics.json.

Classes

Functions
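The docstring does not list which metrics end up in reports/metrics.json. As an illustration only, the sketch below computes accuracy, per-class recall and a confusion matrix from integer labels and predictions, then writes them as JSON. The metric names, key layout and the write_metrics helper are assumptions.

```python
from __future__ import annotations

import json
from pathlib import Path


def write_metrics(y_true: list[int], y_pred: list[int], num_classes: int, out_path: Path | str) -> dict:
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("y_true and y_pred must be non-empty and equally long")
    # confusion[t][p] counts samples of true class t predicted as class p.
    confusion = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        confusion[t][p] += 1
    correct = sum(confusion[i][i] for i in range(num_classes))
    # Recall per class: diagonal count over row total (None if the class never occurs in y_true).
    recall = [
        confusion[i][i] / sum(confusion[i]) if sum(confusion[i]) else None
        for i in range(num_classes)
    ]
    metrics = {
        "accuracy": correct / len(y_true),
        "per_class_recall": recall,
        "confusion_matrix": confusion,
    }
    out = Path(out_path)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(json.dumps(metrics, indent=2))
    return metrics
```

Per-class recall is worth reporting alongside accuracy because chest X-ray class splits are usually imbalanced, and a high overall accuracy can hide one weak class.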

Inference

predict

Inference CLI — load a checkpoint and predict on input(s).

Functions

load_model
load_model(checkpoint_path: str | Path)

Load a Lightning module from checkpoint, rebuilding the backbone from hparams.

Source code in src/chest_xray_classifier/inference/predict.py
def load_model(checkpoint_path: str | Path):
    """Load a Lightning module from checkpoint, rebuilding the backbone from hparams."""
    import torch

    from ..models import ClassificationModule, build_model

    ckpt = torch.load(str(checkpoint_path), map_location="cpu", weights_only=False)
    hp = ckpt.get("hyper_parameters", {})
    model_name = hp.get("model_name")
    num_classes = hp.get("num_classes")
    if model_name is None or num_classes is None:
        raise ValueError(
            "Checkpoint missing model_name/num_classes hparams — "
            "re-train after upgrading ClassificationModule."
        )
    backbone = build_model(model_name, num_classes=num_classes, pretrained=False)
    return ClassificationModule.load_from_checkpoint(str(checkpoint_path), model=backbone)
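load_model only accepts checkpoints whose hyper_parameters entry carries model_name and num_classes, presumably stored by ClassificationModule via Lightning's save_hyperparameters. A dependency-free sketch of that pre-flight check, operating on the plain dict that torch.load returns for a Lightning checkpoint; read_backbone_hparams is a hypothetical helper, not part of the package.

```python
from __future__ import annotations

from typing import Any


def read_backbone_hparams(ckpt: dict[str, Any]) -> tuple[str, int]:
    """Return (model_name, num_classes) from a Lightning checkpoint dict, as load_model expects."""
    hp = ckpt.get("hyper_parameters", {})
    model_name = hp.get("model_name")
    num_classes = hp.get("num_classes")
    if model_name is None or num_classes is None:
        # Same failure mode as load_model: older checkpoints lack these hparams.
        raise ValueError("Checkpoint missing model_name/num_classes hparams")
    return model_name, int(num_classes)
```

Because load_model calls torch.load with weights_only=False, which unpickles arbitrary objects, it should only be pointed at checkpoints from trusted sources. Note also that the checkpoint is read twice: once by torch.load and again inside load_from_checkpoint.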
predict
predict(model, input_path: str | Path)

Run a single prediction on one image file. Returns a dict with probs (softmax probabilities, one per class) and pred (the index of the most probable class).

Source code in src/chest_xray_classifier/inference/predict.py
def predict(model, input_path: str | Path):
    """Run a single prediction. Returns a task-specific result dict."""
    import torch
    from PIL import Image

    from ..data.transforms import build_eval_transforms

    model.eval()
    tf = build_eval_transforms()
    img = Image.open(input_path).convert("RGB")
    x = tf(img).unsqueeze(0)
    with torch.no_grad():
        logits = model._forward_logits(x) if hasattr(model, "_forward_logits") else model(x)
        probs = logits.softmax(-1).squeeze(0).tolist()
    return {"probs": probs, "pred": int(max(range(len(probs)), key=probs.__getitem__))}
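predict returns a class index, not a class name. Assuming the indices follow the class_to_idx mapping of the ImageDataset used for training (sorted subdirectory names), a small hypothetical helper can turn the result dict into a label and its probability. The class names in the test are placeholders.

```python
from __future__ import annotations


def decode_prediction(result: dict, class_to_idx: dict[str, int]) -> tuple[str, float]:
    """Map predict()'s {"probs", "pred"} dict to (class name, probability of that class)."""
    idx_to_class = {i: c for c, i in class_to_idx.items()}
    pred = result["pred"]
    return idx_to_class[pred], result["probs"][pred]
```

Since pred is computed as the argmax of probs, the returned probability is always the largest value in probs.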

Serving

dependencies

Dependency injection — singleton model loader.

Functions
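The dependencies module is described only as a singleton model loader. One common way to get that behaviour is a module-level cache that is filled once at startup and read by every request, raising ModelNotLoadedError until it is filled. The sketch below is a stand-in: ModelNotLoadedError is redefined locally, and set_model/get_model are hypothetical names. In FastAPI, a function like get_model would be passed to Depends.

```python
from __future__ import annotations

import threading
from typing import Any, Callable


class ModelNotLoadedError(RuntimeError):
    """Local stand-in for the serving errors module's ModelNotLoadedError."""


_model: Any = None
_lock = threading.Lock()


def set_model(loader: Callable[[], Any]) -> Any:
    """Load the model once (e.g. at app startup); later calls reuse the cached instance."""
    global _model
    with _lock:
        if _model is None:
            _model = loader()
        return _model


def get_model() -> Any:
    """Dependency for request handlers; fails fast if startup has not loaded the model."""
    if _model is None:
        raise ModelNotLoadedError("model checkpoint not loaded yet")
    return _model
```

Raising ModelNotLoadedError from the dependency lets model_not_loaded_handler turn requests that arrive before loading finishes into 503 model_not_ready responses.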

errors

Exception types and handlers.

Classes

ModelNotLoadedError

Bases: RuntimeError

Raised when the checkpoint has not been loaded into the running app yet.

InferenceError

Bases: RuntimeError

Raised on irrecoverable errors during the forward pass (e.g. a corrupt image or CUDA OOM).

Functions

inference_error_handler async
inference_error_handler(request: Request, exc: InferenceError) -> JSONResponse

FastAPI handler: map InferenceError to a 503 with a structured payload.

Source code in src/chest_xray_classifier/serving/errors.py
async def inference_error_handler(request: Request, exc: InferenceError) -> JSONResponse:
    """FastAPI handler: map `InferenceError` to a 503 with a structured payload."""
    return JSONResponse(
        status_code=503,
        content={
            "error": "inference_failed",
            "detail": str(exc),
            "request_id": getattr(request.state, "request_id", None),
        },
    )

model_not_loaded_handler async
model_not_loaded_handler(
    request: Request, exc: ModelNotLoadedError
) -> JSONResponse

FastAPI handler: map ModelNotLoadedError to a 503 model_not_ready response.

Source code in src/chest_xray_classifier/serving/errors.py
async def model_not_loaded_handler(request: Request, exc: ModelNotLoadedError) -> JSONResponse:
    """FastAPI handler: map `ModelNotLoadedError` to a 503 `model_not_ready` response."""
    return JSONResponse(
        status_code=503,
        content={
            "error": "model_not_ready",
            "detail": str(exc),
            "request_id": getattr(request.state, "request_id", None),
        },
    )
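Both handlers return HTTP 503 with the same payload keys (error, detail, request_id), so a client can tell them apart only through the error field. The minimal client-side sketch below assumes a retry policy: model_not_ready is worth retrying because the model may still be loading, while inference_failed is documented as irrecoverable. should_retry is a hypothetical helper.

```python
from __future__ import annotations

import json


def should_retry(status_code: int, body: str) -> bool:
    """Decide whether to retry a /predict call based on the structured 503 payload."""
    if status_code != 503:
        return False
    try:
        payload = json.loads(body)
    except json.JSONDecodeError:
        # A 503 without the structured payload (e.g. from a proxy) is treated as transient.
        return True
    # model_not_ready: checkpoint still loading, so waiting can help.
    # inference_failed: InferenceError is documented as irrecoverable.
    return isinstance(payload, dict) and payload.get("error") == "model_not_ready"
```

request_id is null in both payloads unless something upstream, such as a middleware, sets request.state.request_id.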

main

FastAPI application.

Classes

Functions

routes

FastAPI routes.

Classes

Functions

schemas

Pydantic request/response schemas.

Classes

HealthResponse

Bases: BaseModel

Response payload of /health — liveness plus whether the model is loaded.

PredictionResponse

Bases: BaseModel

Response payload of /predict — argmax class index plus full softmax probabilities.
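The field names of PredictionResponse are not shown here. Assuming they mirror the dict returned by predict (pred and probs), a client can sanity-check a response: the probabilities are non-negative, sum to 1 within floating-point tolerance, and pred is the index of the largest one. check_prediction_payload is a hypothetical helper.

```python
from __future__ import annotations

import math


def check_prediction_payload(payload: dict, num_classes: int) -> None:
    """Raise ValueError if a /predict payload is internally inconsistent (assumed keys: pred, probs)."""
    probs = payload["probs"]
    pred = payload["pred"]
    if len(probs) != num_classes:
        raise ValueError(f"expected {num_classes} probabilities, got {len(probs)}")
    # Softmax outputs in float32 rarely sum to exactly 1.0, hence the tolerance.
    if any(p < 0 for p in probs) or not math.isclose(sum(probs), 1.0, abs_tol=1e-4):
        raise ValueError("probs is not a probability distribution")
    if probs[pred] != max(probs):
        raise ValueError("pred is not the argmax of probs")
```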

Utilities

hf_hub

Hugging Face Hub helpers.

logging

Structured logging configuration.
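The logging module's API is not listed here. As a sketch of what structured logging usually means, the stdlib-only formatter below emits one JSON object per log record and carries an optional request_id, the same identifier the serving error payloads expose. JsonFormatter, configure_logging and the chosen keys are assumptions; the real module may use a dedicated library instead.

```python
from __future__ import annotations

import json
import logging


class JsonFormatter(logging.Formatter):
    """Render each log record as a single-line JSON object."""

    def format(self, record: logging.LogRecord) -> str:
        payload = {
            "time": self.formatTime(record),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }
        # Values passed via logger.info(..., extra={"request_id": ...}) become record attributes.
        request_id = getattr(record, "request_id", None)
        if request_id is not None:
            payload["request_id"] = request_id
        if record.exc_info:
            payload["exc_info"] = self.formatException(record.exc_info)
        return json.dumps(payload)


def configure_logging(level: int = logging.INFO) -> None:
    """Replace the root logger's handlers with one JSON-emitting stream handler."""
    handler = logging.StreamHandler()
    handler.setFormatter(JsonFormatter())
    root = logging.getLogger()
    root.handlers[:] = [handler]
    root.setLevel(level)
```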

seed

Deterministic seeding across libraries.
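The seed module's function names are not shown. A minimal sketch of seeding across libraries, assuming the usual trio of Python's random, NumPy and PyTorch; the optional imports are skipped when a library is missing, and seed_everything_sketch is a hypothetical name. Lightning itself provides lightning.pytorch.seed_everything, which covers the same ground.

```python
from __future__ import annotations

import os
import random


def seed_everything_sketch(seed: int) -> None:
    # Only affects child processes: hash randomization is fixed when the interpreter starts.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)  # legacy global NumPy RNG
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)  # seeds the RNG on all devices, CUDA included
    except ImportError:
        pass
```

Seeding alone does not make GPU training bit-for-bit reproducible; that usually also needs torch.use_deterministic_algorithms(True) and per-worker seeding in DataLoaders.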