Data layer

Read side and write side of the data layer, both expressed as Protocols (mirroring the engine's substrate seam one layer up).

  • Read: CohortRepository produces a Cohort bundle; RGSCohortRepository is the production DB-backed implementation.
  • Write: PrescriptionStore persists the recommendation output and answers the idempotency question; RGSPrescriptionStore is production.
  • Clinical mappers: ClinicalSubscales and ProtocolToClinicalMapper feed the offline PPF/similarity computation.

data

Data layer — Repository pattern.

Single entry point for everything the recommendation engine needs to read. Mirrors the EngineState Protocol pattern from engine.py one layer up:

CohortRepository  Protocol      (abstract source of cohorts)
    ↑ ↑
    │ └── SyntheticCohortRepository   (future — see SYNTHETIC_DATA_PLAN.md)
    └──── RGSCohortRepository       (production)
            ↓
          Cohort  dataclass            (typed bundle of frames)
            ↓
          DataPipeline.process(cohort) (downstream)

A Cohort is the complete set of data the pipeline needs for one recommendation call. The CohortRepository.find(patient_ids) call is the one-shot fetch. The pipeline + engine downstream don't care which repository implementation served the call.

This file is sectioned:

SECTION 1  File-IO primitives — read_yaml/csv/parquet helpers,
           subscale decoder, whitelist loader. Pure functions.
SECTION 2  Cohort dataclass — the typed bundle the engine consumes.
SECTION 3  CohortRepository Protocol — abstract source contract.
SECTION 4  RGSCohortRepository — production implementation
           (DB via rgs_interface + local Parquet/CSV).
SECTION 5  Clinical mappers — ClinicalSubscales +
           ProtocolToClinicalMapper. Used by the loader's
           specialized accessors for PPF/similarity computation.
SECTION 6  Write side — PrescriptionStore Protocol +
           RGSPrescriptionStore. Symmetric to CohortRepository:
           the engine reads a Cohort, the store writes the
           recommendation output back + answers idempotency.

Cohort dataclass

Cohort(
    patient: DataFrame,
    session: DataFrame,
    ppf: DataFrame,
    similarity: DataFrame,
    whitelist: List[int],
    missing_ppf: List[int],
)

One cohort's worth of data — the full bundle the pipeline + engine consume.

Attributes:

  • patient (DataFrame): One row per patient in the cohort. Anchors the clinical window.
  • session (DataFrame): One row per (patient, protocol, session). Raw observed sessions windowed by the loader's patient list, NOT yet date-clamped.
  • ppf (DataFrame): One row per (patient, protocol): the PPF cohort. Drives the engine's env-wide alternative set.
  • similarity (DataFrame): Long-form (PROTOCOL_A, PROTOCOL_B, SIMILARITY). Already filtered to the whitelist on both sides.
  • whitelist (list[int]): The allowed-protocols list applied to filter session, ppf, and similarity. Here for audit/trace, not for re-filtering.
  • missing_ppf (list[int]): Patient IDs that had no PPF rows on disk (the loader injects placeholder rows with PPF=None for these). Empty for healthy cohorts. Non-empty means callers should refuse to recommend until PPF is computed.
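
The missing_ppf contract can be enforced with a small guard at the call site. The sketch below is illustrative only: MiniCohort, MissingPPFError, and require_ppf are hypothetical names that are not part of ai_cdss, and MiniCohort keeps just the two list-typed fields so the example runs without pandas.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class MiniCohort:
    """Hypothetical stand-in for Cohort: only the list-typed fields."""
    whitelist: List[int] = field(default_factory=list)
    missing_ppf: List[int] = field(default_factory=list)


class MissingPPFError(RuntimeError):
    """Hypothetical error: PPF has not been computed for some patients."""


def require_ppf(cohort: MiniCohort) -> None:
    # A non-empty missing_ppf means placeholder rows (PPF=None) were injected,
    # so any recommendation would rest on absent data.
    if cohort.missing_ppf:
        raise MissingPPFError(
            f"PPF not computed for patients: {sorted(cohort.missing_ppf)}"
        )


healthy = MiniCohort(whitelist=[201, 202])
require_ppf(healthy)  # passes silently

stale = MiniCohort(whitelist=[201, 202], missing_ppf=[12, 7])
try:
    require_ppf(stale)
except MissingPPFError as exc:
    message = str(exc)
```

Raising rather than returning a flag keeps the refusal explicit; whether the real callers raise, log, or skip the patient is not stated in this page.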

CohortRepository

Bases: Protocol

Read interface every cohort source must implement.

The minimum contract is find(patient_ids) -> Cohort. Concrete repositories MAY offer specialized accessors (patient_subscales, protocol_attributes) for offline workflows (PPF + similarity computation), but those are NOT part of the protocol — they live only on the concrete classes.
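
To illustrate the "minimum contract plus optional extras" idea, here is a minimal sketch using typing.Protocol. It is a simplified stand-in, not the module's code: CohortSource, DictCohortSource, and run_pipeline are hypothetical names, and a plain dict replaces the Cohort dataclass so the example needs only the standard library.

```python
from typing import Dict, List, Protocol, runtime_checkable


@runtime_checkable
class CohortSource(Protocol):
    """Simplified stand-in for CohortRepository: one required method."""

    def find(self, patient_ids: List[int]) -> Dict[str, list]: ...


class DictCohortSource:
    """Hypothetical in-memory source; note it does not inherit the Protocol."""

    def __init__(self, sessions: Dict[int, List[int]]) -> None:
        self._sessions = sessions  # patient_id -> protocol ids observed

    def find(self, patient_ids: List[int]) -> Dict[str, list]:
        rows = [
            (pid, protocol)
            for pid in patient_ids
            for protocol in self._sessions.get(pid, [])
        ]
        return {"patient": list(patient_ids), "session": rows}

    def patient_subscales(self, patient_ids: List[int]) -> list:
        # Extra accessor: allowed on the concrete class, absent from the Protocol.
        return []


def run_pipeline(source: CohortSource, patient_ids: List[int]) -> int:
    # Downstream code depends only on find().
    return len(source.find(patient_ids)["session"])


source = DictCohortSource({1: [201, 202], 2: [201]})
n_sessions = run_pipeline(source, [1, 2])
```

Because conformance is structural, any class with a matching find method can be passed to run_pipeline, which is what lets a synthetic or test repository replace the production one.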

RGSCohortRepository

RGSCohortRepository(
    db: Optional[DatabaseInterface] = None,
    rgs_mode: str = "plus",
    whitelist: Optional[List[int]] = None,
)

Production CohortRepository. Single concrete implementation today; satisfies the protocol contract.

Source code in src\ai_cdss\data.py
def __init__(
    self,
    db: Optional[DatabaseInterface] = None,
    rgs_mode: str = "plus",
    whitelist: Optional[List[int]] = None,
) -> None:
    self.interface = db or DatabaseInterface()
    self.rgs_mode = rgs_mode
    self.whitelist = whitelist if whitelist is not None else load_whitelist()

find

find(patient_ids: List[int]) -> Cohort

One-shot fetch + filter + assemble. Returns the full Cohort bundle ready for the pipeline.

Source code in src\ai_cdss\data.py
def find(self, patient_ids: List[int]) -> Cohort:
    """One-shot fetch + filter + assemble. Returns the full Cohort
    bundle ready for the pipeline."""
    ppf = self._fetch(_load_ppf_data, patient_ids, name="ppf")
    missing_ppf = list(ppf.attrs.get("missing_patients", []))

    session = self._fetch(
        lambda p: self.interface.fetch_rgs_data(p, rgs_mode=self.rgs_mode),
        patient_ids, name="sessions",
    )
    patient = self._fetch(
        self.interface.fetch_clinical_data, patient_ids, name="patient",
    )
    similarity = self._fetch(
        lambda _: _load_protocol_similarity(), patient_ids, name="similarity",
    )

    logger.info("Loaded data for patients: %s", patient_ids)
    logger.info("Session data shape: %s", session.shape)
    logger.info("PPF data shape: %s", ppf.shape)

    # Apply whitelist filter to session / ppf / similarity.
    if self.whitelist:
        allowed = set(self.whitelist)
        if PROTOCOL_ID in session.columns:
            session = session[session[PROTOCOL_ID].isin(allowed)]
        if PROTOCOL_ID in ppf.columns:
            ppf = ppf[ppf[PROTOCOL_ID].isin(allowed)]
        # similarity is long-form — filter both sides of the pair.
        similarity = similarity[similarity[PROTOCOL_A].isin(allowed)]
        similarity = similarity[similarity[PROTOCOL_B].isin(allowed)]

    return Cohort(
        patient=patient,
        session=session,
        ppf=ppf,
        similarity=similarity,
        whitelist=list(self.whitelist),
        missing_ppf=missing_ppf,
    )
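
The two-sided similarity filter in find can be checked on a toy frame. This is a sketch under assumptions: the column names are written as literal strings here, whereas the module uses the PROTOCOL_A, PROTOCOL_B, and SIMILARITY constants whose actual values are not shown, and the protocol IDs are made up.

```python
import pandas as pd

# Toy long-form similarity frame; 999 stands for a non-whitelisted protocol.
similarity = pd.DataFrame(
    {
        "PROTOCOL_A": [201, 201, 202, 999],
        "PROTOCOL_B": [201, 999, 201, 999],
        "SIMILARITY": [1.0, 0.4, 0.7, 1.0],
    }
)
allowed = {201, 202}

# A pair survives only if BOTH of its protocols are whitelisted.
mask = similarity["PROTOCOL_A"].isin(allowed) & similarity["PROTOCOL_B"].isin(allowed)
filtered = similarity[mask]
pairs = list(zip(filtered["PROTOCOL_A"], filtered["PROTOCOL_B"]))
```

Combining the two conditions in one mask gives the same rows as the two successive filters in find.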

patient_subscales

patient_subscales(patient_ids: List[int]) -> DataFrame

Patient subscales from clinical_data.CLINICAL_SCORES (JSON-encoded). Decode the latest evaluation per patient, return as a flat DataFrame indexed by PATIENT_ID.

Source code in src\ai_cdss\data.py
def patient_subscales(self, patient_ids: List[int]) -> pd.DataFrame:
    """Patient subscales from `clinical_data.CLINICAL_SCORES`
    (JSON-encoded). Decode the latest evaluation per patient,
    return as a flat DataFrame indexed by PATIENT_ID."""
    patient = self._fetch(
        self.interface.fetch_clinical_data, patient_ids, name="patient",
    )
    decoded = patient.apply(decode_subscales, axis=1)
    return decoded.set_index(PATIENT_ID)

protocol_attributes

protocol_attributes(
    file_path: Optional[str] = None,
) -> DataFrame

Protocol attributes from local CSV (with embedded-package fallback).

Source code in src\ai_cdss\data.py
def protocol_attributes(self, file_path: Optional[str] = None) -> pd.DataFrame:
    """Protocol attributes from local CSV (with embedded-package
    fallback)."""
    return _load_protocol_attributes(file_path=file_path)

fetch_and_validate_patients

fetch_and_validate_patients(
    study_ids: List[int],
) -> List[int]

Patient IDs for one or more study cohorts. Returns [] (with a warning) when no patients are found — callers expect this empty-list contract.

Source code in src\ai_cdss\data.py
def fetch_and_validate_patients(self, study_ids: List[int]) -> List[int]:
    """Patient IDs for one or more study cohorts. Returns `[]` (with
    a warning) when no patients are found — callers expect this
    empty-list contract."""
    patient_data = self.interface.fetch_patients_by_study(study_ids=study_ids)
    if patient_data is None or patient_data.empty:
        logger.warning("No patients found for study IDs %s", study_ids)
        return []
    return patient_data[PATIENT_ID].tolist()

ClinicalSubscales

ClinicalSubscales(scale_yaml_path: Optional[str] = None)

Patient subscale-scores → deficit-matrix transformer.

Reads max-subscale values from a YAML config (default: the embedded config/scales.yaml). The deficit matrix is 1 - (score / max) so higher values mean larger deficits.

Source code in src\ai_cdss\data.py
def __init__(self, scale_yaml_path: Optional[str] = None) -> None:
    if scale_yaml_path:
        self.scales_path = Path(scale_yaml_path)
    else:
        self.scales_path = importlib.resources.files(config) / Path(SCALES_YAML)
    if not self.scales_path.exists():
        raise FileNotFoundError(f"Scale YAML file not found at {self.scales_path}")
    self.scales_dict = MultiKeyDict.from_yaml(self.scales_path)

compute_deficit_matrix

compute_deficit_matrix(patient_df: DataFrame) -> DataFrame

Compute deficit matrix given patient clinical scores.

Source code in src\ai_cdss\data.py
def compute_deficit_matrix(self, patient_df: pd.DataFrame) -> pd.DataFrame:
    """Compute deficit matrix given patient clinical scores."""
    max_subscales = [self.scales_dict.get(scale, None) for scale in patient_df.columns]
    if None in max_subscales:
        missing_subscales = [
            scale for scale, max_val in zip(patient_df.columns, max_subscales)
            if max_val is None
        ]
        raise ValueError(f"Missing max values for subscales: {missing_subscales}")

    deficit_matrix = 1 - (
        patient_df / pd.Series(max_subscales, index=patient_df.columns)
    )
    deficit_matrix.rename(self.scales_dict._keys, axis=1, inplace=True)
    return deficit_matrix
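
As a worked instance of the 1 - (score / max) formula, the sketch below applies it to two patients and two subscales. The subscale names and maximum values are invented for illustration; the real maxima come from the scales YAML.

```python
import pandas as pd

# Rows are patients, columns are subscales (hypothetical names and values).
scores = pd.DataFrame(
    {"scale_a": [33.0, 66.0], "scale_b": [5.0, 0.0]},
    index=[1, 2],
)
max_values = pd.Series({"scale_a": 66.0, "scale_b": 10.0})

# DataFrame / Series aligns the Series index with the frame's columns,
# so each column is divided by its own maximum.
deficit = 1 - scores / max_values
```

Patient 2 scores the maximum on scale_a (deficit 0.0) and zero on scale_b (deficit 1.0), which matches the reading that higher values mean larger deficits.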

ProtocolToClinicalMapper

ProtocolToClinicalMapper(
    mapping_yaml_path: Optional[str] = None,
)

Protocol attribute frame → clinical-scale frame.

Reads the protocol→subscale mapping from YAML (default: config/mapping.yaml). Each clinical scale becomes a column whose value is agg_func (default mean) over the protocol attributes that map to it.

Source code in src\ai_cdss\data.py
def __init__(self, mapping_yaml_path: Optional[str] = None) -> None:
    if mapping_yaml_path:
        self.mapping_path = Path(mapping_yaml_path)
    else:
        self.mapping_path = importlib.resources.files(config) / Path(MAPPING_YAML)
    if not self.mapping_path.exists():
        raise FileNotFoundError(f"Mapping YAML file not found at {self.mapping_path}")
    self.mapping = MultiKeyDict.from_yaml(self.mapping_path)

map_protocol_features

map_protocol_features(
    protocol_df: DataFrame, agg_func=mean
) -> DataFrame

Map protocol-level features into clinical scales.

Source code in src\ai_cdss\data.py
def map_protocol_features(
    self, protocol_df: pd.DataFrame, agg_func=np.mean,
) -> pd.DataFrame:
    """Map protocol-level features into clinical scales."""
    df_clinical = pd.DataFrame(index=protocol_df.index)
    for clinical_scale, features in self.mapping.items():
        df_clinical[clinical_scale] = protocol_df[features].apply(agg_func, axis=1)
    df_clinical.index = protocol_df[PROTOCOL_ID]
    return df_clinical
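
The aggregation step can be illustrated on a toy attribute frame. This is a sketch, assuming a plain dict in place of the MultiKeyDict loaded from YAML and a literal "PROTOCOL_ID" column name; the attribute and scale names are invented.

```python
import numpy as np
import pandas as pd

protocol_df = pd.DataFrame(
    {
        "PROTOCOL_ID": [201, 202],
        "attr_x": [1.0, 0.0],
        "attr_y": [0.0, 1.0],
        "attr_z": [1.0, 1.0],
    }
)
# Hypothetical mapping: clinical scale -> protocol attributes feeding it.
mapping = {"scale_a": ["attr_x", "attr_y"], "scale_b": ["attr_z"]}

clinical = pd.DataFrame(index=protocol_df.index)
for scale, features in mapping.items():
    # Row-wise mean over the attributes mapped to this scale.
    clinical[scale] = protocol_df[features].apply(np.mean, axis=1)
clinical.index = protocol_df["PROTOCOL_ID"]
```

Each protocol ends up with one value per clinical scale: 0.5 on scale_a for both protocols (mean of 1.0 and 0.0) and 1.0 on scale_b.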

PrescriptionStore

Bases: Protocol

Write interface every prescription sink must implement.

Implementations:

  • RGSPrescriptionStore: production, DB via rgs_interface.
  • InMemoryPrescriptionStore: tests / synthetic backtests (future).

RGSPrescriptionStore

RGSPrescriptionStore(
    db: Optional[DatabaseInterface] = None,
)

Production PrescriptionStore backed by DatabaseInterface.

Constructed with the same interface instance as RGSCohortRepository (see CDSS.__init__) so read + write share one DB connection.

Source code in src\ai_cdss\data.py
def __init__(self, db: Optional[DatabaseInterface] = None) -> None:
    self.interface = db or DatabaseInterface()

already_prescribed

already_prescribed(
    patient_id: int, week_start: date
) -> bool

True if prescription_staging already has any row (any STATUS) for (patient_id, week_start). Swallows query errors as not-prescribed — a failed check must not block a fresh run.

Source code in src\ai_cdss\data.py
def already_prescribed(self, patient_id: int, week_start: date) -> bool:
    """True if `prescription_staging` already has any row (any STATUS)
    for `(patient_id, week_start)`. Swallows query errors as
    not-prescribed — a failed check must not block a fresh run."""
    engine = getattr(self.interface, "engine", None)
    if engine is None:
        return False
    sql = (
        "SELECT COUNT(*) AS n FROM prescription_staging "
        "WHERE PATIENT_ID = :pid AND DATE(STARTING_DATE) = :wk"
    )
    try:
        df = self.interface._fetch(
            query=sql,
            params={"pid": int(patient_id), "wk": week_start.isoformat()},
        )
        return bool(df is not None and not df.empty and int(df.iloc[0]["n"]) > 0)
    except Exception:
        logger.exception(
            "Duplication check failed for patient %s; treating as "
            "not-prescribed.", patient_id,
        )
        return False
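
The idempotency question reduces to a lookup on the (patient_id, week_start) pair. The following sketch shows that idea with an in-memory test double. InMemoryStore, its save method, and prescribe_once are hypothetical names for illustration; this page does not show the store's write method, so none of this should be read as the real API.

```python
from datetime import date
from typing import Set, Tuple


class InMemoryStore:
    """Hypothetical test double; `save` is an assumed name, not the real API."""

    def __init__(self) -> None:
        self._seen: Set[Tuple[int, date]] = set()

    def save(self, patient_id: int, week_start: date) -> None:
        self._seen.add((int(patient_id), week_start))

    def already_prescribed(self, patient_id: int, week_start: date) -> bool:
        # Same key as the SQL check: one patient, one week start.
        return (int(patient_id), week_start) in self._seen


def prescribe_once(store: InMemoryStore, patient_id: int, week_start: date) -> bool:
    """Return True if a new prescription was written, False if skipped."""
    if store.already_prescribed(patient_id, week_start):
        return False
    store.save(patient_id, week_start)
    return True


store = InMemoryStore()
week = date(2024, 1, 1)
first = prescribe_once(store, 7, week)
second = prescribe_once(store, 7, week)
other_week = prescribe_once(store, 7, date(2024, 1, 8))
```

The second call for the same patient and week is skipped, while a different week writes again. The production check differs in one respect worth keeping in mind: it treats a failed query as not-prescribed, so a database error can let a duplicate through rather than block the run.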

read_yaml

read_yaml(path: str | Path) -> dict

Load a YAML file into a plain dict.

Source code in src\ai_cdss\data.py
def read_yaml(path: str | Path) -> dict:
    """Load a YAML file into a plain dict."""
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)

read_csv

read_csv(
    file_path: Optional[Path | str] = None,
    default_filename: Optional[str] = None,
) -> DataFrame

Read a CSV from file_path (if given) or from DEFAULT_DATA_DIR / default_filename. Copies the file to the default directory if it came from elsewhere — convenient for caching.

Source code in src\ai_cdss\data.py
def read_csv(
    file_path: Optional[Path | str] = None,
    default_filename: Optional[str] = None,
) -> pd.DataFrame:
    """Read a CSV from `file_path` (if given) or from `DEFAULT_DATA_DIR /
    default_filename`. Copies the file to the default directory if it
    came from elsewhere — convenient for caching."""
    if file_path is not None:
        file_path = Path(file_path)
    else:
        if default_filename is None:
            raise ValueError("Either file_path or default_filename must be provided.")
        file_path = DEFAULT_DATA_DIR / default_filename

    if not file_path.exists():
        raise FileNotFoundError(
            f"File not found: {file_path}. Ensure the correct path is specified."
        )

    try:
        df = pd.read_csv(file_path, index_col=0)
        default_file_path = DEFAULT_DATA_DIR / file_path.name
        if file_path.parent != DEFAULT_DATA_DIR:
            DEFAULT_DATA_DIR.mkdir(parents=True, exist_ok=True)
            if default_file_path.exists():
                logger.warning(
                    "Overwriting existing file in default directory: %s", default_file_path,
                )
            shutil.copy(file_path, default_file_path)
            logger.info("File copied to default directory: %s", default_file_path)
        return df
    except Exception as e:
        raise ValueError(f"Error reading {file_path}: {e}") from e

decode_subscales

decode_subscales(
    row: Series,
    subscales_column: str = CLINICAL_SCORES,
    id_column: str = PATIENT_ID,
) -> Series

Decode the latest clinical-subscale evaluation from the JSON-encoded CLINICAL_SCORES column into a flat Series.

The DB stores patient subscales as a JSON array of evaluations. Take the most recent entry ([-1]), keep only nested-dict entries (subscale groups, dropping metadata like evaluation_date), flatten via pd.json_normalize.

Source code in src\ai_cdss\data.py
def decode_subscales(
    row: pd.Series,
    subscales_column: str = CLINICAL_SCORES,
    id_column: str = PATIENT_ID,
) -> pd.Series:
    """Decode the latest clinical-subscale evaluation from the
    JSON-encoded `CLINICAL_SCORES` column into a flat Series.

    The DB stores patient subscales as a JSON array of evaluations.
    Take the most recent entry (`[-1]`), keep only nested-dict entries
    (subscale groups, dropping metadata like `evaluation_date`), flatten
    via `pd.json_normalize`.
    """
    data = json.loads(row[subscales_column])[-1]
    subscales = {k: v for k, v in data.items() if isinstance(v, dict)}
    flat = pd.json_normalize(subscales).iloc[0]
    flat[id_column] = row[id_column]
    return flat
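
The decoding steps can be traced on a small payload. The JSON below is invented for illustration (the real subscale group names and fields are not shown on this page), and the sketch assumes, as the function does, that the last array element is the most recent evaluation.

```python
import json

import pandas as pd

# Hypothetical CLINICAL_SCORES value: a JSON array of evaluations.
raw = json.dumps(
    [
        {"evaluation_date": "2024-01-01", "motor": {"arm": 10, "hand": 4}},
        {"evaluation_date": "2024-03-01", "motor": {"arm": 12, "hand": 6}},
    ]
)

latest = json.loads(raw)[-1]  # most recent entry, by position
# Keep only nested dicts (subscale groups); scalar metadata is dropped.
groups = {k: v for k, v in latest.items() if isinstance(v, dict)}
# json_normalize flattens nested keys into dotted column names.
flat = pd.json_normalize(groups).iloc[0]
```

The result is a Series with entries motor.arm = 12 and motor.hand = 6; evaluation_date is gone because its value is a string, not a dict.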

load_whitelist

load_whitelist(
    path: Optional[str | Path] = None,
) -> List[int]

Load the AISN-trial-approved protocol set from a YAML config.

Replaces the v0.3.1 ProtocolWhitelistService class — it was 14 lines of class for one YAML read.

Source code in src\ai_cdss\data.py
def load_whitelist(path: Optional[str | Path] = None) -> List[int]:
    """Load the AISN-trial-approved protocol set from a YAML config.

    Replaces the v0.3.1 `ProtocolWhitelistService` class — it was 14
    lines of class for one YAML read.
    """
    if path is None:
        path = importlib.resources.files(config) / Path(PROTOCOL_WHITELIST_YAML)
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"Whitelist YAML not found at {path}")
    return read_yaml(path)["recommendations"]["allowed_protocols"]