Skip to content

Data Splits

In most data science workflows, data is split into subsets (for example, train and test) for analysis and comparison. To support this, DataInterface subclasses accept a list of DataSplit objects through `data_splits`, each describing the logic used to select one subset, and `split_data()` applies them and returns the resulting splits keyed by label.

Split types

Column Name and Value

  • Split data based on a column value.
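  • Supports an optional inequality sign (e.g. `<`) for comparing the column against the value; without one, rows matching the value are selected (see the `test` split below).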
  • Works with Pandas and Polars DataFrames.

Example

import polars as pl
from opsml import PolarsData, DataSplit, CardInfo

info = CardInfo(name="data", repository="mlops", contact="user@mlops.com")

df = pl.DataFrame(
    {
        "foo": [1, 2, 3, 4, 5, 6],
        "bar": ["a", "b", "c", "d", "e", "f"],
        "y": [1, 2, 3, 4, 5, 6],
    }
)

interface = PolarsData(
    info=info,
    data=df,
    data_splits=[
        # Rows where foo < 6 -> 5 rows.
        DataSplit(label="train", column_name="foo", column_value=6, inequality="<"),
        # No inequality given: selects the single row where foo is 6 (see assert below).
        DataSplit(label="test", column_name="foo", column_value=6),
    ],
)

splits = interface.split_data()
assert splits["train"].X.shape[0] == 5
assert splits["test"].X.shape[0] == 1

Indices

  • Split data based on pre-defined indices
  • Works with NDArray, pyarrow.Table, pandas.DataFrame and polars.DataFrame
import numpy as np
from opsml import NumpyData, DataSplit, CardInfo

info = CardInfo(name="data", repository="mlops", contact="user@mlops.com")

data = np.random.rand(10, 10)

interface = NumpyData(
    info=info,
    data=data,
    data_splits=[
        # Select rows 0, 1 and 5 -> 3 rows.
        DataSplit(label="train", indices=[0, 1, 5]),
    ],
)

splits = interface.split_data()
assert splits["train"].X.shape[0] == 3

Start and Stop Slicing

  • Split data based on row slices with a start and stop index
  • Works with NDArray, pyarrow.Table, pandas.DataFrame and polars.DataFrame
import numpy as np
from opsml import NumpyData, DataSplit, CardInfo

info = CardInfo(name="data", repository="mlops", contact="user@mlops.com")

data = np.random.rand(10, 10)

interface = NumpyData(
    info=info,
    data=data,
    data_splits=[
        # Rows from start=0 up to stop=3 -> 3 rows (stop is exclusive, per the assert below).
        DataSplit(label="train", start=0, stop=3),
    ],
)

splits = interface.split_data()
assert splits["train"].X.shape[0] == 3

opsml.DataSplit

Bases: BaseModel

Creates a data split based on the provided logic.

Parameters:

Name Type Description Default
label

Label for the split

required
column_name

Column name to split on

None
column_value

Column value to split on. Can be a string, float, int, or timestamp.

None
inequality

Inequality sign to split on (surrounding whitespace is stripped)

None
start

Start index to split on

None
stop

Stop index to split on

None
indices

List of indices to split on

None
column_type

Type of column_value ("builtin" or "timestamp"). Set to "timestamp" automatically when column_value is a pd.Timestamp

"builtin"
Source code in opsml/data/splitter.py
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
class DataSplit(BaseModel):
    """Creates a data split based on the provided logic.

    A split is typically described by one of: a column comparison
    (``column_name`` + ``column_value`` with an optional ``inequality``),
    a row slice (``start``/``stop``), or explicit row ``indices``.

    Args:
        label:
            Label for the split
        column_name:
            Column name to split on
        column_value:
            Column value to split on. Can be a string, float, int, or timestamp.
        inequality:
            Inequality sign to split on. Surrounding whitespace is stripped.
        start:
            Start index to split on
        stop:
            Stop index to split on
        indices:
            List of indices to split on. Other iterables of ints (tuple,
            range, numpy array, ...) are converted to a list.
        column_type:
            Type of ``column_value``: ``"timestamp"`` when it is a
            pd.Timestamp, otherwise ``"builtin"``. Set automatically; when
            re-loading a split with ``column_type="timestamp"``, a serialized
            (str) value is converted back to a pd.Timestamp.

    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    label: str
    column_name: Optional[str] = None
    column_value: Optional[Union[str, float, int, pd.Timestamp]] = None
    inequality: Optional[str] = None
    start: Optional[int] = None
    stop: Optional[int] = None
    indices: Optional[List[int]] = None
    column_type: str = "builtin"

    @model_validator(mode="before")
    @classmethod
    def check_timestamp(cls, model_args: Any) -> Any:
        """Keep ``column_value`` and ``column_type`` consistent for timestamps.

        - If ``column_type`` is ``"timestamp"`` and ``column_value`` is not a
          pd.Timestamp (e.g. it was serialized to str), it is converted back.
        - If ``column_value`` is a pd.Timestamp, ``column_type`` is set to
          ``"timestamp"``.

        Args:
            model_args:
                Raw input passed to the model.

        Returns:
            The (possibly copied and updated) raw input.
        """
        # "before" validators receive raw input, which is not guaranteed to be
        # a dict (e.g. an existing model instance); let pydantic handle those.
        if not isinstance(model_args, dict):
            return model_args

        column_value = model_args.get("column_value")
        if column_value is None:
            return model_args

        # Copy so a dict handed to ``model_validate`` is not mutated in place.
        model_args = dict(model_args)

        if model_args.get("column_type") == "timestamp" and not isinstance(column_value, pd.Timestamp):
            column_value = pd.Timestamp(column_value)
            model_args["column_value"] = column_value

        if isinstance(column_value, pd.Timestamp):
            model_args["column_type"] = "timestamp"

        return model_args

    @field_validator("indices", mode="before")
    @classmethod
    def convert_to_list(cls, value: Any) -> Any:
        """Pre-validator that converts ``indices`` to a list if not None.

        Objects exposing ``tolist()`` (e.g. numpy arrays) are converted with
        it so elements become plain Python ints; other iterables go through
        ``list()``. Strings and non-iterables are passed through unchanged so
        pydantic reports a regular ValidationError instead of a raw TypeError
        (or, for strings, silently splitting into characters).
        """
        if value is None or isinstance(value, list):
            return value

        if isinstance(value, (str, bytes)):
            return value

        tolist = getattr(value, "tolist", None)
        if callable(tolist):
            return tolist()

        try:
            return list(value)
        except TypeError:
            # Not iterable: let field validation produce the error.
            return value

    @field_validator("inequality", mode="before")
    @classmethod
    def trim_whitespace(cls, value: Any) -> Any:
        """Trims surrounding whitespace from the inequality sign.

        Non-string values are returned unchanged so pydantic's ``Optional[str]``
        validation reports them, rather than ``.strip()`` raising an
        AttributeError that pydantic would not convert into a ValidationError.
        """
        if isinstance(value, str):
            return value.strip()
        return value

    @field_serializer("column_value", mode="plain")
    def serialize_column_value(
        self,
        column_value: Optional[Union[str, float, int, pd.Timestamp]],
        _info: FieldSerializationInfo,
    ) -> Optional[Union[str, float, int]]:
        """Serializes a pd.Timestamp to str. This is used when saving the data split as a JSON file

        Args:
            column_value:
                Column value to serialize

        Returns:
            Union[str, float, int]: Serialized column value (non-timestamps unchanged)
        """
        if isinstance(column_value, pd.Timestamp):
            return str(column_value)
        return column_value

convert_to_list(value) classmethod

Pre-validator that converts indices to a list if they are not None

Source code in opsml/data/splitter.py
79
80
81
82
83
84
85
86
87
@field_validator("indices", mode="before")
@classmethod
def convert_to_list(cls, value: Any) -> Any:
    """Pre-validator that converts ``indices`` to a list if not None.

    Objects exposing ``tolist()`` (e.g. numpy arrays) are converted with it so
    elements become plain Python ints; other iterables go through ``list()``.
    Strings and non-iterables are passed through unchanged so pydantic reports
    a regular ValidationError instead of a raw TypeError (or, for strings,
    silently splitting into characters).
    """
    if value is None or isinstance(value, list):
        return value

    if isinstance(value, (str, bytes)):
        return value

    tolist = getattr(value, "tolist", None)
    if callable(tolist):
        return tolist()

    try:
        return list(value)
    except TypeError:
        # Not iterable: let field validation produce the error.
        return value

serialize_column_value(column_value, _info)

Serializes a pd.Timestamp column value to str; other values are returned unchanged. This is used when saving the data split as a JSON file.

Parameters:

Name Type Description Default
column_value Optional[Union[str, float, int, Timestamp]]

Column value to serialize

required

Returns:

Type Description
Optional[Union[str, float, int]]

Union[str, float, int]: Serialized column value

Source code in opsml/data/splitter.py
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
@field_serializer("column_value", mode="plain")
def serialize_column_value(
    self,
    column_value: Optional[Union[str, float, int, pd.Timestamp]],
    _info: FieldSerializationInfo,
) -> Optional[Union[str, float, int]]:
    """Converts a pd.Timestamp column value to str for JSON output.

    Used when the data split is saved as a JSON file. Strings, numbers and
    None pass through untouched.

    Args:
        column_value:
            Column value to serialize

    Returns:
        Union[str, float, int]: Serialized column value
    """
    is_timestamp = isinstance(column_value, pd.Timestamp)
    return str(column_value) if is_timestamp else column_value

trim_whitespace(value) classmethod

Trims leading and trailing whitespace from the inequality sign

Source code in opsml/data/splitter.py
89
90
91
92
93
94
95
96
97
@field_validator("inequality", mode="before")
@classmethod
def trim_whitespace(cls, value: Any) -> Any:
    """Trims surrounding whitespace from the inequality sign.

    Non-string values are returned unchanged so pydantic's ``Optional[str]``
    validation reports them, rather than ``.strip()`` raising an
    AttributeError that pydantic would not convert into a ValidationError.
    """
    if isinstance(value, str):
        return value.strip()
    return value