Custom Dataset Example

This page turns the synthetic custom-dataset example into a notebook-style walkthrough. It starts from a raw pandas dataframe, parses the feature types with OCEAN, trains a random forest, and then uses the constraint-programming backend to find a counterfactual that moves one class-0 prediction to class 1.

Why this example matters

The packaged dataset loaders are useful when you want a quick start, but most real integrations begin with a dataframe that you already own. This example shows the full workflow on mixed feature types:

  • ordered discrete values through credit_lines,

  • binary flags through owns_home and has_guarantor,

  • continuous ratios through income_ratio, debt_ratio, and savings_ratio,

  • unordered nominal features through job_type and region.

Cell 1: Build a mixed-type dataframe

import numpy as np
import pandas as pd

# A seeded generator keeps the synthetic table reproducible. The columns are
# drawn in the order they are listed, so the random stream is consumed
# deterministically.
rng = np.random.default_rng(42)
n_rows = 300
columns = {
    # Ordered discrete feature.
    "credit_lines": rng.choice([0, 1, 2, 4], size=n_rows),
    # Binary flags.
    "owns_home": rng.integers(0, 2, size=n_rows),
    "has_guarantor": rng.integers(0, 2, size=n_rows),
    # Continuous ratios.
    "income_ratio": rng.uniform(-0.4, 0.8, size=n_rows),
    "debt_ratio": rng.uniform(0.0, 1.0, size=n_rows),
    "savings_ratio": rng.uniform(-0.5, 0.6, size=n_rows),
    # Unordered nominal features.
    "job_type": rng.choice(
        ["office", "manual", "service", "student"],
        size=n_rows,
    ),
    "region": rng.choice(
        ["north", "south", "east", "west"],
        size=n_rows,
    ),
}
raw = pd.DataFrame(columns)

# Each favourable signal adds one point to the score ...
bonuses = [
    (raw["credit_lines"] >= 2).astype(int),
    raw["owns_home"].astype(int),
    raw["has_guarantor"].astype(int),
    (raw["income_ratio"] > 0.1).astype(int),
    (raw["savings_ratio"] > 0.0).astype(int),
    raw["job_type"].isin(["office", "service"]).astype(int),
    raw["region"].isin(["north", "east"]).astype(int),
]
# ... and a high debt ratio removes one.
penalty = (raw["debt_ratio"] > 0.55).astype(int)
score = sum(bonuses[1:], bonuses[0]) - penalty

# Binary label: 1 when the score reaches 4, otherwise 0.
target = (score >= 4).astype(int).rename("approved")

Cell 2: Parse the features with OCEAN

from ocean.feature import parse_features

# credit_lines is the only column declared explicitly (as discrete); no other
# type hints are passed to parse_features.
data, mapper = parse_features(raw, discretes=("credit_lines",))
# Show the column layout of the processed frame.
print(data.columns)
MultiIndex([(  'credit_lines',        ''),
            (     'owns_home',        ''),
            ( 'has_guarantor',        ''),
            (  'income_ratio',        ''),
            (    'debt_ratio',        ''),
            ( 'savings_ratio',        ''),
            (      'job_type',  'manual'),
            (      'job_type',  'office'),
            (      'job_type', 'service'),
            (      'job_type', 'student'),
            (        'region',    'east'),
            (        'region',   'north'),
            (        'region',   'south'),
            (        'region',    'west')],
           )

The important part is that credit_lines stays ordered and numeric, while job_type and region expand into one-hot blocks.

Cell 3: Fit a classifier and choose a query

import pandas as pd
from sklearn.ensemble import RandomForestClassifier

# Small, shallow forest with a fixed seed.
forest_params = {"n_estimators": 40, "max_depth": 4, "random_state": 42}
model = RandomForestClassifier(**forest_params)
model.fit(data, target)

# The query is the first training row that the fitted model assigns to class 0.
predictions = pd.Series(model.predict(data), index=data.index)
class_zero = predictions.loc[predictions.eq(0)]
query_index = class_zero.index[0]

# Three views of the same row: the untouched raw record, a one-row processed
# frame for model.predict, and a flat float vector for the explainer.
raw_query = raw.loc[query_index]
query_frame = data.loc[[query_index]]
query = query_frame.to_numpy(dtype=float).flatten()

print(raw_query)
print()
query_label = model.predict(query_frame).item()
print("Model prediction:", int(query_label))
credit_lines            0
owns_home               0
has_guarantor           1
income_ratio    -0.349179
debt_ratio       0.260349
savings_ratio    0.234634
job_type          student
region               west
Name: 0, dtype: object

Model prediction: 0

Cell 4: Explain the query

from ocean import ConstraintProgrammingExplainer

explainer = ConstraintProgrammingExplainer(model, mapper=mapper)

# Ask for a counterfactual that the model assigns to class 1.
target_class = 1
solver_options = {
    "norm": 1,
    "max_time": 10,
    "num_workers": 1,
    "random_seed": 42,
}
explanation = explainer.explain(query, y=target_class, **solver_options)
if explanation is None:
    msg = "No counterfactual was found for the synthetic example."
    raise RuntimeError(msg)

# Rebuild a one-row frame with the training columns so the model can score
# the counterfactual.
counterfactual_frame = pd.DataFrame(
    [explanation.to_numpy()],
    columns=data.columns,
)
counterfactual_label = model.predict(counterfactual_frame).item()

print("Target class:", target_class)
print("Counterfactual prediction:", int(counterfactual_label))
Target class: 1
Counterfactual prediction: 1

Cell 5: Inspect the decoded explanation

# Printing the explanation object gives the decoded, per-column view.
print(explanation)
Explanation:
credit_lines   : 0.0
owns_home      : 0
has_guarantor  : 1
income_ratio   : -0.2833683341741562
debt_ratio     : -0.1887158378958702
savings_ratio  : 0.29842646420001984
job_type       : student
region         : north

This decoded view is usually the most readable one: categorical one-hot blocks are mapped back to labels, and the keys match the original dataframe columns.

Cell 6: Inspect the processed vector and the final distance

# Processed counterfactual vector as a pandas Series.
print(explanation.to_series())
print()
# Distance value reported by the explainer after the explain() call.
print("Distance:", explainer.get_distance())
credit_lines              0.000000
owns_home                 0.000000
has_guarantor             1.000000
income_ratio             -0.283368
debt_ratio               -0.188716
savings_ratio             0.298426
job_type       manual     0.000000
               office     0.000000
               service    0.000000
               student    1.000000
region         east       0.000000
               north      1.000000
               south      0.000000
               west       0.000000
dtype: float64

Distance: 1.3566914800296987

The value returned by get_distance() is the user-facing metric to report here: it is the post-processed \(L_1\) distance between the original query and the decoded counterfactual, including the half-weight treatment for one-hot blocks.

Full script

The exact runnable version behind this page is listed below:

examples/custom_dataset.py
  1from __future__ import annotations
  2
  3from typing import TYPE_CHECKING
  4
  5import numpy as np
  6import pandas as pd
  7from sklearn.ensemble import RandomForestClassifier
  8
  9from ocean import ConstraintProgrammingExplainer
 10from ocean.feature import parse_features
 11
 12if TYPE_CHECKING:
 13    from ocean.abc import Mapper
 14    from ocean.feature import Feature
 15
 16
 17def build_dataset(
 18    seed: int = 42,
 19    n_samples: int = 300,
 20) -> tuple[pd.DataFrame, pd.DataFrame, pd.Series[int], Mapper[Feature]]:
 21    rng = np.random.default_rng(seed)
 22
 23    raw = pd.DataFrame({
 24        "credit_lines": rng.choice([0, 1, 2, 4], size=n_samples),
 25        "owns_home": rng.integers(0, 2, size=n_samples),
 26        "has_guarantor": rng.integers(0, 2, size=n_samples),
 27        "income_ratio": rng.uniform(-0.4, 0.8, size=n_samples),
 28        "debt_ratio": rng.uniform(0.0, 1.0, size=n_samples),
 29        "savings_ratio": rng.uniform(-0.5, 0.6, size=n_samples),
 30        "job_type": rng.choice(
 31            ["office", "manual", "service", "student"],
 32            size=n_samples,
 33        ),
 34        "region": rng.choice(
 35            ["north", "south", "east", "west"],
 36            size=n_samples,
 37        ),
 38    })
 39
 40    score = (
 41        (raw["credit_lines"] >= 2).astype(int)
 42        + raw["owns_home"].astype(int)
 43        + raw["has_guarantor"].astype(int)
 44        + (raw["income_ratio"] > 0.1).astype(int)
 45        + (raw["savings_ratio"] > 0.0).astype(int)
 46        + raw["job_type"].isin(["office", "service"]).astype(int)
 47        + raw["region"].isin(["north", "east"]).astype(int)
 48        - (raw["debt_ratio"] > 0.55).astype(int)
 49    )
 50    target = (score >= 4).astype(int).rename("approved")
 51
 52    data, mapper = parse_features(raw, discretes=("credit_lines",))
 53    return raw, data, target, mapper
 54
 55
 56def main() -> None:
 57    raw, data, target, mapper = build_dataset()
 58
 59    model = RandomForestClassifier(
 60        n_estimators=40,
 61        max_depth=4,
 62        random_state=42,
 63    )
 64    model.fit(data, target)
 65
 66    predictions = pd.Series(model.predict(data), index=data.index)
 67    query_index = predictions[predictions == 0].index[0]
 68    query = data.loc[query_index].to_numpy(dtype=float).flatten()
 69    query_frame = data.loc[[query_index]]
 70    raw_query = raw.loc[query_index]
 71
 72    explainer = ConstraintProgrammingExplainer(model, mapper=mapper)
 73    explanation = explainer.explain(
 74        query,
 75        y=1,
 76        norm=1,
 77        max_time=10,
 78        num_workers=1,
 79        random_seed=42,
 80    )
 81    if explanation is None:
 82        msg = "No counterfactual was found for the synthetic example."
 83        raise RuntimeError(msg)
 84    counterfactual_frame = pd.DataFrame(
 85        [explanation.to_numpy()],
 86        columns=data.columns,
 87    )
 88
 89    print("Original raw instance:")
 90    print(raw_query)
 91    print()
 92    print("Model prediction:", int(model.predict(query_frame).item()))
 93    print("Target class:", 1)
 94    print(
 95        "Counterfactual prediction:",
 96        int(model.predict(counterfactual_frame).item()),
 97    )
 98    print("Counterfactual values:")
 99    print(explanation)
100    print()
101    print("Counterfactual vector:")
102    print(explanation.to_series())
103    print()
104    print("Distance:", explainer.get_distance())
105
106
107if __name__ == "__main__":
108    main()