opsml.model.interfaces.xgb
from pathlib import Path
from typing import Any, Dict, Optional, Union

import joblib
import pandas as pd
from numpy.typing import NDArray
from pydantic import model_validator

from opsml.helpers.logging import ArtifactLogger
from opsml.helpers.utils import get_class_name
from opsml.model import ModelInterface
from opsml.model.interfaces.base import get_model_args, get_processor_name
from opsml.types import CommonKwargs, ModelReturn, Suffix, TrainedModelType

logger = ArtifactLogger.get_logger()

try:
    from xgboost import Booster, DMatrix, XGBModel

    class XGBoostModel(ModelInterface):
        """Model interface for XGBoost models.

        Supports both native XGBoost ``Booster`` models and Sklearn-flavor
        estimators (``XGBModel`` subclasses such as the XGBoost regressor and
        classifier). ONNX conversion is not supported for ``Booster`` models.

        Args:
            model:
                XGBoost model. Can be either a Booster or XGBModel.
            preprocessor:
                Optional preprocessor
            sample_data:
                Sample data to be used for type inference and ONNX conversion/validation.
                This should match exactly what the model expects as input.
            task_type:
                Task type for model. Defaults to undefined.
            model_type:
                Optional model type. This is inferred automatically.
            preprocessor_name:
                Optional preprocessor name. This is inferred automatically if a
                preprocessor is provided.

        Returns:
            XGBoostModel
        """

        model: Optional[Union[Booster, XGBModel]] = None
        sample_data: Optional[Union[pd.DataFrame, NDArray[Any], DMatrix]] = None
        preprocessor: Optional[Any] = None
        preprocessor_name: str = CommonKwargs.UNDEFINED.value

        @property
        def model_class(self) -> str:
            """Returns the trained-model class: XGBoost booster or sklearn estimator."""
            # check_model stores "Booster" as the model_type for native boosters.
            if "Booster" in self.model_type:
                return TrainedModelType.XGB_BOOSTER.value
            return TrainedModelType.SKLEARN_ESTIMATOR.value

        @classmethod
        def _get_sample_data(cls, sample_data: Any) -> Union[pd.DataFrame, NDArray[Any], DMatrix]:
            """Check sample data and returns one record to be used
            during type inference and ONNX conversion/validation.

            Returns:
                Sample data with only one record
            """
            if isinstance(sample_data, DMatrix):
                # DMatrix is not sliceable like arrays/frames; use its row-index slice API.
                return sample_data.slice([0])
            return super()._get_sample_data(sample_data)

        @model_validator(mode="before")
        @classmethod
        def check_model(cls, model_args: Dict[str, Any]) -> Dict[str, Any]:
            """Infers model_type, trims sample_data and records data/preprocessor names.

            Args:
                model_args:
                    Raw constructor arguments. ``sample_data`` must be present
                    (a missing key raises ``KeyError``).

            Returns:
                The (mutated) ``model_args`` dict.
            """
            model = model_args.get("model")

            # NOTE(review): presumably a set modelcard_uid means the interface is
            # being rebuilt from an existing card, so inference is skipped.
            if model_args.get("modelcard_uid", False):
                return model_args

            model, _, bases = get_model_args(model)

            if isinstance(model, XGBModel):
                model_args[CommonKwargs.MODEL_TYPE.value] = model.__class__.__name__

            elif isinstance(model, Booster):
                model_args[CommonKwargs.MODEL_TYPE.value] = "Booster"

            else:
                for base in bases:
                    if "sklearn" in base:
                        model_args[CommonKwargs.MODEL_TYPE.value] = "subclass"

            sample_data = cls._get_sample_data(sample_data=model_args[CommonKwargs.SAMPLE_DATA.value])
            model_args[CommonKwargs.SAMPLE_DATA.value] = sample_data
            model_args[CommonKwargs.DATA_TYPE.value] = get_class_name(sample_data)
            model_args[CommonKwargs.PREPROCESSOR_NAME.value] = get_processor_name(
                model_args.get(CommonKwargs.PREPROCESSOR.value),
            )

            return model_args

        def save_model(self, path: Path) -> None:
            """Saves the XGBoost model to ``path``.

            Booster models are written with ``Booster.save_model`` (xgboost picks
            the format from the file extension). Sklearn-flavor models are
            delegated to the base ``ModelInterface.save_model``.

            Args:
                path:
                    base path to save model to

            Raises:
                AssertionError: If no model is set on the interface.
            """
            assert self.model is not None, "No model found"
            if isinstance(self.model, Booster):
                self.model.save_model(path)

            else:
                super().save_model(path)

        def load_model(self, path: Path, **kwargs: Any) -> None:
            """Loads an XGBoost Booster or Sklearn-flavor model from ``path``.

            Args:
                path:
                    base path to load from
                **kwargs:
                    Additional keyword arguments (currently unused)
            """
            # Use model_class (like the other Booster checks here) instead of the
            # LightGBM enum member that was previously compared against model_type.
            if self.model_class == TrainedModelType.XGB_BOOSTER.value:
                self.model = Booster(model_file=path)
            else:
                super().load_model(path)

        def save_preprocessor(self, path: Path) -> None:
            """Saves the preprocessor to ``path`` using joblib.

            Args:
                path:
                    Pathlib object

            Raises:
                AssertionError: If no preprocessor is set on the interface.
            """
            assert self.preprocessor is not None, "No preprocessor detected in interface"
            joblib.dump(self.preprocessor, path)

        def load_preprocessor(self, path: Path) -> None:
            """Loads the preprocessor from a joblib file.

            Args:
                path:
                    Pathlib object
            """
            # joblib.load unpickles; only load artifacts from trusted storage.
            self.preprocessor = joblib.load(path)

        def save_onnx(self, path: Path) -> ModelReturn:
            """Saves the onnx model

            A warning is logged for Booster models, which are not supported for
            ONNX conversion; the base implementation is still invoked.

            Args:
                path:
                    Path to save

            Returns:
                ModelReturn
            """

            if self.model_class == TrainedModelType.XGB_BOOSTER.value:
                logger.warning("ONNX conversion for XGBoost Booster is not supported")

            return super().save_onnx(path)

        def save_sample_data(self, path: Path) -> None:
            """Serializes and saves sample data to ``path``.

            DMatrix data is written with ``DMatrix.save_binary``; anything else
            is dumped with joblib.

            Args:
                path:
                    Pathlib object
            """
            if isinstance(self.sample_data, DMatrix):
                self.sample_data.save_binary(path)

            else:
                joblib.dump(self.sample_data, path)

        def load_sample_data(self, path: Path) -> None:
            """Loads serialized sample data from ``path``.

            Booster models read the file as a ``DMatrix``; other models use joblib.

            Args:
                path:
                    Pathlib object
            """
            # NOTE(review): save_sample_data branches on the data's type, while this
            # branches on model_class; the two only agree when Booster models are
            # given DMatrix sample data.
            if self.model_class == TrainedModelType.XGB_BOOSTER.value:
                self.sample_data = DMatrix(path)
            else:
                self.sample_data = joblib.load(path)

        @property
        def model_suffix(self) -> str:
            """Returns suffix for storage: JSON for Booster models, base suffix otherwise."""
            # Compare model_class (derived from model_type) for consistency with
            # load_model / data_suffix rather than the raw model_type string.
            if self.model_class == TrainedModelType.XGB_BOOSTER.value:
                return Suffix.JSON.value

            return super().model_suffix

        @property
        def preprocessor_suffix(self) -> str:
            """Returns suffix for storage"""
            return Suffix.JOBLIB.value

        @property
        def data_suffix(self) -> str:
            """Returns suffix for storage"""
            if self.model_class == TrainedModelType.XGB_BOOSTER.value:
                return Suffix.DMATRIX.value
            return Suffix.JOBLIB.value

        @staticmethod
        def name() -> str:
            return XGBoostModel.__name__

except ModuleNotFoundError:
    from opsml.model.interfaces.backups import XGBoostModelNoModule as XGBoostModel
logger =
<builtins.Logger object>
class XGBoostModel(ModelInterface):
    """Model interface for XGBoost models.

    Supports both native XGBoost ``Booster`` models and Sklearn-flavor
    estimators (``XGBModel`` subclasses such as the XGBoost regressor and
    classifier). ONNX conversion is not supported for ``Booster`` models.

    Args:
        model:
            XGBoost model. Can be either a Booster or XGBModel.
        preprocessor:
            Optional preprocessor
        sample_data:
            Sample data to be used for type inference and ONNX conversion/validation.
            This should match exactly what the model expects as input.
        task_type:
            Task type for model. Defaults to undefined.
        model_type:
            Optional model type. This is inferred automatically.
        preprocessor_name:
            Optional preprocessor name. This is inferred automatically if a
            preprocessor is provided.

    Returns:
        XGBoostModel
    """

    model: Optional[Union[Booster, XGBModel]] = None
    sample_data: Optional[Union[pd.DataFrame, NDArray[Any], DMatrix]] = None
    preprocessor: Optional[Any] = None
    preprocessor_name: str = CommonKwargs.UNDEFINED.value

    @property
    def model_class(self) -> str:
        """Returns the trained-model class: XGBoost booster or sklearn estimator."""
        # check_model stores "Booster" as the model_type for native boosters.
        if "Booster" in self.model_type:
            return TrainedModelType.XGB_BOOSTER.value
        return TrainedModelType.SKLEARN_ESTIMATOR.value

    @classmethod
    def _get_sample_data(cls, sample_data: Any) -> Union[pd.DataFrame, NDArray[Any], DMatrix]:
        """Check sample data and returns one record to be used
        during type inference and ONNX conversion/validation.

        Returns:
            Sample data with only one record
        """
        if isinstance(sample_data, DMatrix):
            # DMatrix is not sliceable like arrays/frames; use its row-index slice API.
            return sample_data.slice([0])
        return super()._get_sample_data(sample_data)

    @model_validator(mode="before")
    @classmethod
    def check_model(cls, model_args: Dict[str, Any]) -> Dict[str, Any]:
        """Infers model_type, trims sample_data and records data/preprocessor names.

        Args:
            model_args:
                Raw constructor arguments. ``sample_data`` must be present
                (a missing key raises ``KeyError``).

        Returns:
            The (mutated) ``model_args`` dict.
        """
        model = model_args.get("model")

        # NOTE(review): presumably a set modelcard_uid means the interface is
        # being rebuilt from an existing card, so inference is skipped.
        if model_args.get("modelcard_uid", False):
            return model_args

        model, _, bases = get_model_args(model)

        if isinstance(model, XGBModel):
            model_args[CommonKwargs.MODEL_TYPE.value] = model.__class__.__name__

        elif isinstance(model, Booster):
            model_args[CommonKwargs.MODEL_TYPE.value] = "Booster"

        else:
            for base in bases:
                if "sklearn" in base:
                    model_args[CommonKwargs.MODEL_TYPE.value] = "subclass"

        sample_data = cls._get_sample_data(sample_data=model_args[CommonKwargs.SAMPLE_DATA.value])
        model_args[CommonKwargs.SAMPLE_DATA.value] = sample_data
        model_args[CommonKwargs.DATA_TYPE.value] = get_class_name(sample_data)
        model_args[CommonKwargs.PREPROCESSOR_NAME.value] = get_processor_name(
            model_args.get(CommonKwargs.PREPROCESSOR.value),
        )

        return model_args

    def save_model(self, path: Path) -> None:
        """Saves the XGBoost model to ``path``.

        Booster models are written with ``Booster.save_model`` (xgboost picks
        the format from the file extension). Sklearn-flavor models are
        delegated to the base ``ModelInterface.save_model``.

        Args:
            path:
                base path to save model to

        Raises:
            AssertionError: If no model is set on the interface.
        """
        assert self.model is not None, "No model found"
        if isinstance(self.model, Booster):
            self.model.save_model(path)

        else:
            super().save_model(path)

    def load_model(self, path: Path, **kwargs: Any) -> None:
        """Loads an XGBoost Booster or Sklearn-flavor model from ``path``.

        Args:
            path:
                base path to load from
            **kwargs:
                Additional keyword arguments (currently unused)
        """
        # Use model_class (like the other Booster checks here) instead of the
        # LightGBM enum member that was previously compared against model_type.
        if self.model_class == TrainedModelType.XGB_BOOSTER.value:
            self.model = Booster(model_file=path)
        else:
            super().load_model(path)

    def save_preprocessor(self, path: Path) -> None:
        """Saves the preprocessor to ``path`` using joblib.

        Args:
            path:
                Pathlib object

        Raises:
            AssertionError: If no preprocessor is set on the interface.
        """
        assert self.preprocessor is not None, "No preprocessor detected in interface"
        joblib.dump(self.preprocessor, path)

    def load_preprocessor(self, path: Path) -> None:
        """Loads the preprocessor from a joblib file.

        Args:
            path:
                Pathlib object
        """
        # joblib.load unpickles; only load artifacts from trusted storage.
        self.preprocessor = joblib.load(path)

    def save_onnx(self, path: Path) -> ModelReturn:
        """Saves the onnx model

        A warning is logged for Booster models, which are not supported for
        ONNX conversion; the base implementation is still invoked.

        Args:
            path:
                Path to save

        Returns:
            ModelReturn
        """

        if self.model_class == TrainedModelType.XGB_BOOSTER.value:
            logger.warning("ONNX conversion for XGBoost Booster is not supported")

        return super().save_onnx(path)

    def save_sample_data(self, path: Path) -> None:
        """Serializes and saves sample data to ``path``.

        DMatrix data is written with ``DMatrix.save_binary``; anything else
        is dumped with joblib.

        Args:
            path:
                Pathlib object
        """
        if isinstance(self.sample_data, DMatrix):
            self.sample_data.save_binary(path)

        else:
            joblib.dump(self.sample_data, path)

    def load_sample_data(self, path: Path) -> None:
        """Loads serialized sample data from ``path``.

        Booster models read the file as a ``DMatrix``; other models use joblib.

        Args:
            path:
                Pathlib object
        """
        # NOTE(review): save_sample_data branches on the data's type, while this
        # branches on model_class; the two only agree when Booster models are
        # given DMatrix sample data.
        if self.model_class == TrainedModelType.XGB_BOOSTER.value:
            self.sample_data = DMatrix(path)
        else:
            self.sample_data = joblib.load(path)

    @property
    def model_suffix(self) -> str:
        """Returns suffix for storage: JSON for Booster models, base suffix otherwise."""
        # Compare model_class (derived from model_type) for consistency with
        # load_model / data_suffix rather than the raw model_type string.
        if self.model_class == TrainedModelType.XGB_BOOSTER.value:
            return Suffix.JSON.value

        return super().model_suffix

    @property
    def preprocessor_suffix(self) -> str:
        """Returns suffix for storage"""
        return Suffix.JOBLIB.value

    @property
    def data_suffix(self) -> str:
        """Returns suffix for storage"""
        if self.model_class == TrainedModelType.XGB_BOOSTER.value:
            return Suffix.DMATRIX.value
        return Suffix.JOBLIB.value

    @staticmethod
    def name() -> str:
        return XGBoostModel.__name__
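Model interface for XGBoost models. Both native XGBoost `Booster` models and Sklearn-flavor estimators (`XGBModel` subclasses such as the XGBoost regressor and classifier) are supported; ONNX conversion is not supported for `Booster` models.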
Arguments:
- model: XGBoost model. Can be either a Booster or XGBModel.
- preprocessor: Optional preprocessor
- sample_data: Sample data to be used for type inference and ONNX conversion/validation. This should match exactly what the model expects as input.
- task_type: Task type for model. Defaults to undefined.
- model_type: Optional model type. This is inferred automatically.
- preprocessor_name: Optional preprocessor name. This is inferred automatically if a preprocessor is provided.
Returns:
XGBoostModel
sample_data: Union[pandas.core.frame.DataFrame, numpy.ndarray[Any, numpy.dtype[Any]], xgboost.core.DMatrix, NoneType]
@model_validator(mode='before')
@classmethod
def
check_model(cls, model_args: Dict[str, Any]) -> Dict[str, Any]:
@model_validator(mode="before")
@classmethod
def check_model(cls, model_args: Dict[str, Any]) -> Dict[str, Any]:
    """Pre-validation hook that fills in derived interface fields.

    Unless ``modelcard_uid`` is truthy (then ``model_args`` is returned as-is),
    this infers ``model_type`` from the model, trims ``sample_data`` to a single
    record via ``_get_sample_data``, and records ``data_type`` and
    ``preprocessor_name``.

    Args:
        model_args:
            Raw constructor arguments; ``sample_data`` must be present
            (a missing key raises ``KeyError``).

    Returns:
        The same ``model_args`` dict, updated in place.
    """
    if model_args.get("modelcard_uid", False):
        return model_args

    candidate, _, base_names = get_model_args(model_args.get("model"))
    type_key = CommonKwargs.MODEL_TYPE.value

    if isinstance(candidate, XGBModel):
        model_args[type_key] = candidate.__class__.__name__
    elif isinstance(candidate, Booster):
        model_args[type_key] = "Booster"
    elif any("sklearn" in name for name in base_names):
        # User subclass of an sklearn-based model.
        model_args[type_key] = "subclass"

    data_key = CommonKwargs.SAMPLE_DATA.value
    one_record = cls._get_sample_data(sample_data=model_args[data_key])
    model_args[data_key] = one_record
    model_args[CommonKwargs.DATA_TYPE.value] = get_class_name(one_record)

    processor = model_args.get(CommonKwargs.PREPROCESSOR.value)
    model_args[CommonKwargs.PREPROCESSOR_NAME.value] = get_processor_name(processor)

    return model_args
def
save_model(self, path: pathlib.Path) -> None:
def save_model(self, path: Path) -> None:
    """Saves the XGBoost model to ``path``.

    Booster models are written with ``Booster.save_model`` (xgboost picks
    the format from the file extension). Sklearn-flavor models are
    delegated to the base ``ModelInterface.save_model``.

    Args:
        path:
            base path to save model to

    Raises:
        AssertionError: If no model is set on the interface.
    """
    assert self.model is not None, "No model found"
    if isinstance(self.model, Booster):
        self.model.save_model(path)

    else:
        super().save_model(path)
Saves the XGBoost model. Booster models are saved with `Booster.save_model` (format chosen by xgboost from the file extension). Sklearn-flavor models are saved via the base implementation (joblib).
Arguments:
- path: base path to save model to
def
load_model(self, path: pathlib.Path, **kwargs: Any) -> None:
def load_model(self, path: Path, **kwargs: Any) -> None:
    """Loads an XGBoost Booster or Sklearn-flavor model from ``path``.

    Args:
        path:
            base path to load from
        **kwargs:
            Additional keyword arguments (currently unused)
    """
    # Use model_class (like the other Booster checks in this interface) instead
    # of comparing model_type against the LightGBM enum member.
    if self.model_class == TrainedModelType.XGB_BOOSTER.value:
        self.model = Booster(model_file=path)
    else:
        super().load_model(path)
Loads an XGBoost Booster or Sklearn-flavor model
Arguments:
- path: base path to load from
- **kwargs: Additional keyword arguments
def
save_preprocessor(self, path: pathlib.Path) -> None:
def save_preprocessor(self, path: Path) -> None:
    """Dump the interface's preprocessor to ``path`` with joblib.

    Args:
        path:
            Destination file path.

    Raises:
        AssertionError: If no preprocessor is set on the interface.
    """
    current = self.preprocessor
    assert current is not None, "No preprocessor detected in interface"
    joblib.dump(current, path)
Saves the preprocessor to path using joblib. Raises an AssertionError if no preprocessor is set.
Arguments:
- path: Pathlib object
def
load_preprocessor(self, path: pathlib.Path) -> None:
def load_preprocessor(self, path: Path) -> None:
    """Restore the preprocessor from a joblib file at ``path``.

    Args:
        path:
            File previously written by ``save_preprocessor``.
    """
    # joblib.load unpickles; only load artifacts from trusted storage.
    restored = joblib.load(path)
    self.preprocessor = restored
Load preprocessor from pathlib object
Arguments:
- path: Pathlib object
def
save_onnx(self, path: pathlib.Path) -> opsml.types.model.ModelReturn:
def save_onnx(self, path: Path) -> ModelReturn:
    """Convert the model to ONNX and save it via the base implementation.

    For Booster models a warning is logged first, since ONNX conversion is
    not supported for them; the base implementation is still called.

    Args:
        path:
            Destination path.

    Returns:
        ModelReturn produced by the base implementation.
    """
    booster_class = TrainedModelType.XGB_BOOSTER.value
    if self.model_class == booster_class:
        logger.warning("ONNX conversion for XGBoost Booster is not supported")
    return super().save_onnx(path)
Saves the onnx model
Arguments:
- path: Path to save
Returns:
ModelReturn
def
save_sample_data(self, path: pathlib.Path) -> None:
def save_sample_data(self, path: Path) -> None:
    """Serialize the sample data to ``path``.

    ``DMatrix`` data is written with ``DMatrix.save_binary``; any other
    value is dumped with joblib.

    Args:
        path:
            Destination file path.
    """
    data = self.sample_data
    if not isinstance(data, DMatrix):
        joblib.dump(data, path)
        return
    data.save_binary(path)
Serializes and saves sample data to path.
Arguments:
- path: Pathlib object
def
load_sample_data(self, path: pathlib.Path) -> None:
def load_sample_data(self, path: Path) -> None:
    """Loads serialized sample data from ``path``.

    Booster models read the file as an XGBoost ``DMatrix``; all other models
    load it with joblib.

    Args:
        path:
            Pathlib object
    """
    # NOTE(review): save_sample_data branches on the data's type, while this
    # branches on model_class; the two only agree when Booster models are
    # given DMatrix sample data.
    if self.model_class == TrainedModelType.XGB_BOOSTER.value:
        self.sample_data = DMatrix(path)
    else:
        self.sample_data = joblib.load(path)
Loads serialized sample data from path.
Arguments:
- path: Pathlib object
model_suffix: str
@property
def model_suffix(self) -> str:
    """Returns the storage suffix for the model.

    JSON for Booster models; otherwise the base interface's suffix.
    """
    # Compare model_class (derived from model_type) for consistency with the
    # other Booster checks rather than the raw model_type string.
    if self.model_class == TrainedModelType.XGB_BOOSTER.value:
        return Suffix.JSON.value

    return super().model_suffix
Returns suffix for storage
preprocessor_suffix: str
@property
def preprocessor_suffix(self) -> str:
    """Storage suffix for the serialized preprocessor (always joblib)."""
    suffix: str = Suffix.JOBLIB.value
    return suffix
Returns suffix for storage
data_suffix: str
@property
def data_suffix(self) -> str:
    """Storage suffix for sample data: DMatrix for Booster models, joblib otherwise."""
    is_booster = self.model_class == TrainedModelType.XGB_BOOSTER.value
    return Suffix.DMATRIX.value if is_booster else Suffix.JOBLIB.value
Returns suffix for storage
model_config =
{'protected_namespaces': ('protect_',), 'arbitrary_types_allowed': True, 'validate_assignment': False, 'validate_default': True, 'extra': 'allow'}
model_fields =
{'model': FieldInfo(annotation=Union[Booster, XGBModel, NoneType], required=False), 'sample_data': FieldInfo(annotation=Union[DataFrame, ndarray[Any, dtype[Any]], DMatrix, NoneType], required=False), 'onnx_model': FieldInfo(annotation=Union[OnnxModel, NoneType], required=False), 'task_type': FieldInfo(annotation=str, required=False, default='undefined'), 'model_type': FieldInfo(annotation=str, required=False, default='undefined'), 'data_type': FieldInfo(annotation=str, required=False, default='undefined'), 'modelcard_uid': FieldInfo(annotation=str, required=False, default=''), 'preprocessor': FieldInfo(annotation=Union[Any, NoneType], required=False), 'preprocessor_name': FieldInfo(annotation=str, required=False, default='undefined')}
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- dict
- json
- parse_obj
- parse_raw
- parse_file
- from_orm
- construct
- copy
- schema
- schema_json
- validate
- update_forward_refs