Onnx Args

Some model interfaces require extra arguments when converting a model to onnx. Supply them through the onnx_args argument of the ModelInterface class.

TorchOnnxArgs

TorchOnnxArgs is the optional onnx args class for TorchModel and LightningModel. When not supplied, a default TorchOnnxArgs class is used. For more information on the arguments, please refer to the torch.onnx documentation.

opsml.TorchOnnxArgs

Bases: BaseModel

Optional arguments to pass to torch when converting to onnx

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input_names | List[str] | List containing input names for model inputs. | required |
| output_names | List[str] | List containing output names for model outputs. | required |
| dynamic_axes | Optional[Dict[str, Dict[int, str]]] | Optional PyTorch attribute that defines dynamic axes. | None |
| do_constant_folding | bool | Whether to use constant folding optimization. | True |

Source code in opsml/types/model.py
class TorchOnnxArgs(BaseModel):
    """Optional arguments to pass to torch when converting to onnx

    Args:
        input_names:
            Optional list containing input names for model inputs.
        output_names:
            Optional list containing output names for model outputs.
        dynamic_axes:
            Optional PyTorch attribute that defines dynamic axes
        constant_folding:
            Whether to use constant folding optimization. Default is True
    """

    input_names: List[str]
    output_names: List[str]
    dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None
    do_constant_folding: bool = True
    export_params: bool = True
    verbose: bool = False
    options: Optional[Dict[str, Any]] = None
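
To make the field list above concrete, the sketch below mirrors the fields with a standard-library dataclass and shows one plausible way they could be collected as keyword arguments for `torch.onnx.export`. The class `TorchOnnxArgsSketch` and the helper `to_export_kwargs` are hypothetical stand-ins, not part of opsml; how opsml actually forwards these values is not shown on this page, and `options` is left out because its use is not described here.

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional


@dataclass
class TorchOnnxArgsSketch:
    """Stand-in mirroring the TorchOnnxArgs fields above (not the opsml class)."""

    input_names: List[str]
    output_names: List[str]
    dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None
    do_constant_folding: bool = True
    export_params: bool = True
    verbose: bool = False
    options: Optional[Dict[str, Any]] = None


def to_export_kwargs(args: TorchOnnxArgsSketch) -> Dict[str, Any]:
    """Hypothetical helper: gather the fields as keyword arguments."""
    kwargs = asdict(args)
    # `options` is dropped: this page does not say how opsml uses it.
    kwargs.pop("options")
    # Leave out anything that was not set.
    return {key: value for key, value in kwargs.items() if value is not None}


onnx_args = TorchOnnxArgsSketch(
    input_names=["input"],
    output_names=["output"],
    # Mark axis 0 of both tensors as a variable-length "batch" dimension.
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
)
export_kwargs = to_export_kwargs(onnx_args)
print(export_kwargs)
```

The remaining keys (`input_names`, `output_names`, `dynamic_axes`, `do_constant_folding`, `export_params`, `verbose`) match keyword arguments accepted by `torch.onnx.export`.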

HuggingFaceOnnxArgs

HuggingFaceOnnxArgs is the REQUIRED onnx args class for HuggingFaceModel when converting a model to onnx format. HuggingFaceOnnxArgs is a custom object that allows you to specify how optimum should convert your model to onnx.

Arguments

ort_type
Optimum onnx class name as defined in HuggingFaceORTModel. Required.
provider
Onnx runtime provider to use. Defaults to CPUExecutionProvider
quantize
Whether or not to quantize the model. Defaults to False. If True, a quantization config is required.
config
Optional config for conversion. Can be one of AutoQuantizationConfig, ORTConfig or QuantizationConfig. See optimum for more details.
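
The defaults and the quantize/config rule described above can be summarized in a small stand-in. This is a standard-library dataclass, not the opsml class: `HuggingFaceOnnxArgsSketch` and its `__post_init__` check are hypothetical, and this page does not show where opsml enforces that quantization needs a config.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class HuggingFaceOnnxArgsSketch:
    """Stand-in with the same fields and defaults as listed above."""

    ort_type: str
    provider: str = "CPUExecutionProvider"
    quantize: bool = False
    config: Optional[Any] = None

    def __post_init__(self) -> None:
        # Assumed rule, taken from the prose: quantizing needs a config.
        if self.quantize and self.config is None:
            raise ValueError("quantize=True requires a quantization config")


# Only ort_type has to be supplied; everything else falls back to defaults.
args = HuggingFaceOnnxArgsSketch(ort_type="ORTModelForSequenceClassification")
print(args.provider, args.quantize, args.config)

try:
    HuggingFaceOnnxArgsSketch(
        ort_type="ORTModelForSequenceClassification",
        quantize=True,
    )
    quantize_error = ""
except ValueError as err:
    quantize_error = str(err)
    print(quantize_error)
```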

opsml.HuggingFaceOnnxArgs

Bases: BaseModel

Optional Args to use with a huggingface model

Parameters:

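| Name | Type | Description | Default |
| --- | --- | --- | --- |
| ort_type | str | Optimum onnx class name | required |
| provider | str | Onnx runtime provider to use | "CPUExecutionProvider" |
| quantize | bool | Whether or not to quantize the model | False |
| config | Optional[Any] | Optional optimum config to use | None |
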
Source code in opsml/types/model.py
class HuggingFaceOnnxArgs(BaseModel):
    """Optional Args to use with a huggingface model

    Args:
        ort_type:
            Optimum onnx class name
        provider:
            Onnx runtime provider to use
        config:
            Optional optimum config to use
    """

    ort_type: str
    provider: str = "CPUExecutionProvider"
    quantize: bool = False
    config: Optional[Any] = None

    @field_validator("ort_type", mode="before")
    @classmethod
    def check_ort_type(cls, ort_type: str) -> str:
        """Validates onnx runtime model type"""
        if ort_type not in list(HuggingFaceORTModel):
            raise ValueError(f"Optimum model type {ort_type} is not supported")
        return ort_type

    @field_validator("config", mode="before")
    @classmethod
    def check_config(cls, config: Optional[Any] = None) -> Optional[Any]:
        """Check that optimum config is valid"""

        if config is None:
            return config

        from optimum.onnxruntime import (
            AutoQuantizationConfig,
            ORTConfig,
            QuantizationConfig,
        )

        assert isinstance(
            config,
            (
                AutoQuantizationConfig,
                ORTConfig,
                QuantizationConfig,
            ),
        ), "config must be a valid optimum config"

        return config

check_config(config=None) classmethod

Check that optimum config is valid

Source code in opsml/types/model.py
@field_validator("config", mode="before")
@classmethod
def check_config(cls, config: Optional[Any] = None) -> Optional[Any]:
    """Check that optimum config is valid"""

    if config is None:
        return config

    from optimum.onnxruntime import (
        AutoQuantizationConfig,
        ORTConfig,
        QuantizationConfig,
    )

    assert isinstance(
        config,
        (
            AutoQuantizationConfig,
            ORTConfig,
            QuantizationConfig,
        ),
    ), "config must be a valid optimum config"

    return config
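
The behavior of this validator can be illustrated without optimum installed. The sketch below uses two empty stand-in classes in place of the optimum config types (both hypothetical) and raises explicitly instead of using `assert`, since Python strips `assert` statements when run with `-O`. It is a sketch of the same check, not the opsml implementation.

```python
from typing import Any, Optional, Tuple


class QuantizationConfigStub:
    """Stand-in for an optimum config class (hypothetical, illustration only)."""


class ORTConfigStub:
    """Stand-in for a second accepted config class (hypothetical)."""


ACCEPTED_CONFIGS: Tuple[type, ...] = (QuantizationConfigStub, ORTConfigStub)


def check_config_sketch(config: Optional[Any] = None) -> Optional[Any]:
    """Pass None through unchanged; otherwise require an accepted config type."""
    if config is None:
        return config
    # An explicit raise keeps the check active even under `python -O`.
    if not isinstance(config, ACCEPTED_CONFIGS):
        raise TypeError("config must be a valid optimum config")
    return config


cfg = QuantizationConfigStub()
instance_passes = check_config_sketch(cfg) is cfg
none_passes = check_config_sketch(None) is None

try:
    # A plain dict is not one of the accepted config types.
    check_config_sketch({"is_static": False})
    rejected = False
except TypeError:
    rejected = True

print(instance_passes, none_passes, rejected)
```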

check_ort_type(ort_type) classmethod

Validates onnx runtime model type

Source code in opsml/types/model.py
@field_validator("ort_type", mode="before")
@classmethod
def check_ort_type(cls, ort_type: str) -> str:
    """Validates onnx runtime model type"""
    if ort_type not in list(HuggingFaceORTModel):
        raise ValueError(f"Optimum model type {ort_type} is not supported")
    return ort_type
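
The membership test above compares a plain string against `list(HuggingFaceORTModel)`. A minimal sketch of why that can work, assuming `HuggingFaceORTModel` is a `str`-based `Enum` whose values are Optimum class names: the enum below is a hypothetical stand-in with two illustrative members, since the real members are not listed on this page.

```python
from enum import Enum


class HuggingFaceORTModelSketch(str, Enum):
    """Stand-in enum; the real HuggingFaceORTModel members are not shown here."""

    SEQUENCE_CLASSIFICATION = "ORTModelForSequenceClassification"
    QUESTION_ANSWERING = "ORTModelForQuestionAnswering"


def check_ort_type_sketch(ort_type: str) -> str:
    """Return ort_type unchanged if it names a known member, else raise."""
    # Members of a str-based Enum compare equal to their string values,
    # so a plain string can be tested against list(Enum).
    if ort_type not in list(HuggingFaceORTModelSketch):
        raise ValueError(f"Optimum model type {ort_type} is not supported")
    return ort_type


accepted = check_ort_type_sketch("ORTModelForSequenceClassification")

try:
    check_ort_type_sketch("NotAnOrtModel")
    error_message = ""
except ValueError as err:
    error_message = str(err)

print(accepted)
print(error_message)
```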