Source code for onnx_safetensors._safetensors_io

"""Private module for loading and saving safetensors data to ONNX models."""

from __future__ import annotations

import io
import json
import os
import re
import struct
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, TypeVar

import onnx
import onnx_ir as ir
import safetensors
from tqdm.auto import tqdm

from onnx_safetensors import _tensors

if TYPE_CHECKING:
    pass


# Number of bytes read at the very start of a safetensors file to obtain the
# header length (decoded below as a little-endian unsigned 64-bit integer).
_HEADER_SIZE_NUMBER_SIZE = 8
# Bit-width threshold: dtypes narrower than this are treated as sub-byte types
# when computing the shape written to safetensors.
_BYTE_SIZE = 8
# Dtype strings as they appear in a safetensors header, mapped to the IR dtype
# used when building tensors from a file or from serialized bytes.
# https://github.com/huggingface/safetensors/blob/543243c3017e413584f27ebd4b99c844f62deb34/safetensors/src/tensor.rs#L664
_SAFETENSORS_DTYPE_TO_IR_DTYPE = {
    "BOOL": ir.DataType.BOOL,
    "F4": ir.DataType.FLOAT4E2M1,
    "F8_E5M2": ir.DataType.FLOAT8E5M2,
    "F8_E4M3": ir.DataType.FLOAT8E4M3FN,
    "F8_E8M0": ir.DataType.FLOAT8E8M0,
    "BF16": ir.DataType.BFLOAT16,
    "F16": ir.DataType.FLOAT16,
    "F32": ir.DataType.FLOAT,
    "F64": ir.DataType.DOUBLE,
    "I8": ir.DataType.INT8,
    "I16": ir.DataType.INT16,
    "I32": ir.DataType.INT32,
    "I64": ir.DataType.INT64,
    "U8": ir.DataType.UINT8,
    "U16": ir.DataType.UINT16,
    "U32": ir.DataType.UINT32,
    "U64": ir.DataType.UINT64,
    "C64": ir.DataType.COMPLEX64,
}
# IR dtype mapped to the dtype name handed to the safetensors serializer.
# Several IR dtypes (the FNUZ float8 variants and the 2-/4-bit integers) are
# written as "uint8"; _migrate_tensor_shape_dtype restores the model's own
# dtype and shape when such a tensor is attached back to the model.
_IR_DTYPE_TO_SAFETENSORS_DTYPE = {
    ir.DataType.BOOL: "bool",
    ir.DataType.FLOAT4E2M1: "float4_e2m1fn_x2",
    ir.DataType.FLOAT8E5M2: "float8_e5m2",
    ir.DataType.FLOAT8E4M3FN: "float8_e4m3fn",
    ir.DataType.FLOAT8E8M0: "float8_e8m0",
    ir.DataType.FLOAT8E4M3FNUZ: "uint8",
    ir.DataType.FLOAT8E5M2FNUZ: "uint8",
    ir.DataType.BFLOAT16: "bfloat16",
    ir.DataType.FLOAT16: "float16",
    ir.DataType.FLOAT: "float32",
    ir.DataType.DOUBLE: "float64",
    ir.DataType.INT2: "uint8",
    ir.DataType.INT4: "uint8",
    ir.DataType.INT8: "int8",
    ir.DataType.INT16: "int16",
    ir.DataType.INT32: "int32",
    ir.DataType.INT64: "int64",
    ir.DataType.UINT2: "uint8",
    ir.DataType.UINT4: "uint8",
    ir.DataType.UINT8: "uint8",
    ir.DataType.UINT16: "uint16",
    ir.DataType.UINT32: "uint32",
    ir.DataType.UINT64: "uint64",
    ir.DataType.COMPLEX64: "complex64",
}


# Constrained type variable: public functions annotated with it accept either
# an onnx.ModelProto or an ir.Model and are typed to return the same kind.
TModel = TypeVar("TModel", onnx.ModelProto, ir.Model)


def _parse_size_string(size: int | str) -> int:
    """Parse a size string like '5GB' or '100MB' into bytes.

    Args:
        size: Either an integer representing bytes, or a string like '5GB', '100MB', etc.

    Returns:
        The size in bytes.

    Raises:
        ValueError: If the size string format is invalid.
    """
    if isinstance(size, int):
        return size

    size = size.strip()
    match = re.match(r"(\d+(?:\.\d+)?)\s*([A-Za-z]+)", size)
    if not match:
        raise ValueError(
            f"Invalid size format: {size}. Expected format like '5GB' or '100MB'."
        )

    num_str, unit = match.groups()
    num = float(num_str)

    # Convert to bytes
    unit = unit.upper()
    multipliers = {
        "B": 1,
        "KB": 1000,
        "MB": 1000**2,
        "GB": 1000**3,
        "TB": 1000**4,
    }

    if unit not in multipliers:
        raise ValueError(
            f"Unknown size unit: {unit}. Valid units are: {', '.join(multipliers.keys())}"
        )

    return int(num * multipliers[unit])


def _get_shard_filename(base_name: str, shard_idx: int, total_shards: int) -> str:
    """Generate a filename for a shard.

    Args:
        base_name: The base filename (e.g., 'model.safetensors').
        shard_idx: The index of this shard (1-indexed).
        total_shards: The total number of shards.

    Returns:
        The shard filename (e.g., 'model-00001-of-00003.safetensors').
    """
    if total_shards == 1:
        return base_name

    # Extract extension
    if "." in base_name:
        name, ext = base_name.rsplit(".", 1)
        ext = f".{ext}"
    else:
        name = base_name
        ext = ""

    # Always use 5 digits to follow transformers convention
    return f"{name}-{shard_idx:05d}-of-{total_shards:05d}{ext}"


def _shard_tensors(
    tensors: Sequence[ir.TensorProtocol], max_shard_size_bytes: int | None
) -> list[list[ir.TensorProtocol]]:
    """Shard tensors into multiple files based on max_shard_size_bytes.

    Args:
        tensors: The tensors to shard.
        max_shard_size_bytes: Maximum size for each shard in bytes. When None,
            no sharding is performed.

    Returns:
        A list of tensor lists for each shard.
    """
    if max_shard_size_bytes is None:
        # No sharding
        return [list(tensors)]

    # Shard the tensors by current order
    shards: list[list[ir.TensorProtocol]] = [[]]
    current_shard_size = 0

    for tensor in tensors:
        tensor_size = tensor.nbytes
        # Check if adding this tensor would exceed max_shard_size_bytes
        if (
            current_shard_size + tensor_size > max_shard_size_bytes
            and current_shard_size > 0
        ):
            # Start a new shard
            shards.append([])
            current_shard_size = 0

        shards[-1].append(tensor)
        current_shard_size += tensor_size

    return shards


def _apply_tensors(
    values: Sequence[ir.Value],
    tensors: Mapping[str, ir.TensorProtocol],
    apply_safetensors: bool = False,
):
    """Apply tensors to an ONNX model.

    Args:
        values: Values in the ONNX model.
        tensors: Tensors to apply to the ONNX model.
        apply_safetensors: Whether it is applying safetensors to the ONNX model.
    """
    value_map: dict[str, ir.Value] = {value.name: value for value in values}  # type: ignore[misc]
    for name, tensor in tensors.items():
        if name not in value_map:
            continue
        value = value_map[name]
        model_tensor = value_map[name].const_value
        if model_tensor is not None and apply_safetensors:
            assert isinstance(tensor, ir.ExternalTensor)
            _check_tensors_match(model_tensor, tensor)
            updated_tensor = _migrate_tensor_shape_dtype(model_tensor, tensor)
        else:
            updated_tensor = tensor
        value.const_value = updated_tensor


def replace_tensors(
    model: ir.Model, /, location: str | os.PathLike, base_dir: str | os.PathLike
) -> None:
    """Replace all tensors in an ONNX model with external data from a safetensors file.

    The model is modified in place: every initializer that currently holds a
    tensor and whose name appears in the safetensors file is pointed at the
    corresponding external tensor.

    .. versionadded:: 1.0
        Added the function.

    Args:
        model: ONNX model to replace tensors in.
        location: Path to the safetensors file relative to the ONNX model file.
        base_dir: Directory where the ONNX model file is stored.
    """
    external_tensors = _read_safetensors(location, base_dir=base_dir)
    pairs = _get_value_tensor_pairs(model)
    targets = [pair[0] for pair in pairs]
    _apply_tensors(targets, external_tensors, apply_safetensors=True)
def load_file(model: TModel, /, tensor_file: str | os.PathLike) -> TModel:
    """Load external data into ONNX model from a safetensors file.

    The initializers are first pointed at the file with ``replace_tensors``
    and the result is then passed through ``ir.external_data.load_to_model``.

    Args:
        model: ONNX model.
        tensor_file: safetensors file to load into ONNX model.

    Returns:
        The updated model, of the same kind (``onnx.ModelProto`` or
        ``ir.Model``) as the one passed in.

    .. versionchanged:: 1.0
        The return value is now the updated ONNX model instead of a set of loaded tensor names.
    """
    is_proto = isinstance(model, onnx.ModelProto)
    model_ir = ir.serde.deserialize_model(model) if is_proto else model

    # An empty base_dir makes tensor_file resolve as given.
    replace_tensors(model_ir, tensor_file, "")
    model_ir = ir.external_data.load_to_model(model_ir)

    return ir.serde.serialize_model(model_ir) if is_proto else model_ir
def load(model: TModel, /, data: bytes) -> TModel:
    """Load external data into ONNX model from safetensors bytes.

    Each tensor found in ``data`` is assigned to the initializer of the same
    name; tensors without a matching initializer are ignored. The incoming
    tensor is assigned as-is, without comparing it to the one it replaces.

    Args:
        model: ONNX model.
        data: safetensors bytes to load into ONNX model.

    Returns:
        The updated model, of the same kind (``onnx.ModelProto`` or
        ``ir.Model``) as the one passed in.

    .. versionchanged:: 1.0
        The return value is now the updated ONNX model instead of a set of loaded tensor names.
    """
    is_proto = isinstance(model, onnx.ModelProto)
    model_ir = ir.serde.deserialize_model(model) if is_proto else model

    # TODO: Handle more dtypes
    tensors_dict = {}
    for name, metadata in safetensors.deserialize(data):
        tensors_dict[name] = _tensors.ByteArrayTensor(
            data=metadata["data"],
            dtype=_SAFETENSORS_DTYPE_TO_IR_DTYPE[metadata["dtype"]],
            shape=ir.Shape(metadata["shape"]),
            name=name,
        )

    pairs = _get_value_tensor_pairs(model_ir)
    _apply_tensors([pair[0] for pair in pairs], tensors_dict)

    return ir.serde.serialize_model(model_ir) if is_proto else model_ir
def load_file_as_external_data(
    model: TModel, /, location: str | os.PathLike, base_dir: str | os.PathLike = ""
) -> TModel:
    """Load weights from safetensors file and use them as external data for the ONNX model.

    Unlike ``load_file``, this only re-points the initializers at the file via
    ``replace_tensors``; it does not follow up with
    ``ir.external_data.load_to_model``.

    .. versionadded:: 1.0
        Added the function.

    Args:
        model: ONNX model or graph to load external data into.
        location: Path to the safetensors file relative to the ONNX model file.
        base_dir: Directory where the ONNX model file is stored.

    Returns:
        The ONNX model with the external data.
    """
    is_proto = isinstance(model, onnx.ModelProto)
    model_ir = ir.serde.deserialize_model(model) if is_proto else model

    replace_tensors(model_ir, location, base_dir)

    return ir.serde.serialize_model(model_ir) if is_proto else model_ir
def _get_tensor_storage_shape(tensor: ir.TensorProtocol) -> Sequence[int]:
    """Return the shape under which a tensor is written to safetensors.

    Tensors whose dtype is narrower than one byte are written as a flat,
    one-dimensional buffer of ``tensor.nbytes`` entries; every other tensor
    keeps its own shape.
    """
    is_sub_byte = tensor.dtype.bitwidth < _BYTE_SIZE
    return [tensor.nbytes] if is_sub_byte else tensor.shape.numpy()
def save(model: TModel, /, *, size_threshold: int = 0) -> bytes:
    """Save all tensors in an ONNX model to a safetensors object serialized as bytes.

    Only the initializers of the model's main graph (``model.graph``) are
    considered; initializers without tensor data are skipped.

    Args:
        model: ONNX model to save.
        size_threshold: Minimum size in bytes for a tensor to be saved.
            Default is 0, which saves all initializers.

    Returns:
        The safetensors object serialized as bytes.
    """
    if isinstance(model, onnx.ModelProto):
        model_ir = ir.serde.deserialize_model(model)
    else:
        model_ir = model

    tensor_dict: dict[str, dict[str, Any]] = {}
    for name, initializer in model_ir.graph.initializers.items():
        tensor = initializer.const_value
        if tensor is None:
            continue
        # The threshold is documented in bytes, so compare against nbytes
        # (as save_file does) rather than the element count.
        if tensor.nbytes < size_threshold:
            continue
        tensor_dict[name] = {
            "dtype": _IR_DTYPE_TO_SAFETENSORS_DTYPE[tensor.dtype],
            "shape": _get_tensor_storage_shape(tensor),
            # TODO: Return a memoryview when safetensors supports it.
            "data": tensor.tobytes(),
        }
    return safetensors.serialize(tensor_dict)
def _get_value_tensor_pairs( model: ir.Model, ) -> list[tuple[ir.Value, ir.TensorProtocol]]: # Store the original initializer values so they can be restored if modify_model=False value_tensor_pairs: list[tuple[ir.Value, ir.TensorProtocol]] = [] initializer_names: set[str] = set() for graph in model.graphs(): for value in graph.initializers.values(): tensor = value.const_value # The value.name should be the same as tensor.name. However, # in case there is a conflict, we do not care and will prefer value.name. name = value.name if name is None: raise ValueError( f"Initializer value '{value!r}' has no name (in graph {graph.name!r}). " "All initializers must have names." ) if tensor is None: continue if name in initializer_names: raise ValueError( f"Duplicate initializer name found: {name} (in graph {graph.name!r})." " Rename the initializers to have unique names before saving to safetensors." ) initializer_names.add(name) value_tensor_pairs.append((value, tensor)) return value_tensor_pairs
def save_file(  # noqa: PLR0912
    model: TModel,
    /,
    location: str | os.PathLike,
    base_dir: str | os.PathLike = "",
    *,
    size_threshold: int = 0,
    replace_data: bool = True,
    max_shard_size: int | str | None = None,
) -> TModel:
    """Save all tensors in an ONNX model to a safetensors file.

    Tensors are stored under the name of the initializer value that owns
    them, which is the same name ``replace_tensors`` uses to attach them back.

    .. versionadded:: 1.0.0
        The *replace_data* parameter was added to allow the user to choose
        whether to replace the data in the ONNX model with the external data.
    .. versionremoved:: 1.0.0
        The *convert_attributes* and *strip_data* parameters were removed.
        Set *replace_data* to achieve similar effect as *strip_data*.
    .. versionchanged:: 1.0.0
        The return value is now the updated ONNX model instead of a set of
        saved tensor names.
    .. versionadded:: 1.0.1
        The *base_dir* parameter was added so the external data can be
        referenced relative to the ONNX model file correctly.
    .. versionadded:: 1.3.0
        The *max_shard_size* parameter was added to support sharding large models.

    Args:
        model: ONNX model proto to save.
        location: Path to the safetensors file relative to the ONNX model file.
        base_dir: Directory where the ONNX model file is stored.
        size_threshold: Minimum size in bytes for a tensor to be saved.
            Default is 0, which saves all tensors.
        replace_data: Whether to replace the data in the ONNX model with
            the external data. Default is True.
        max_shard_size: Maximum size in bytes (as int) or as a string with unit
            (like "5GB" or "100MB") for a checkpoint before being sharded.
            If None, no sharding is performed.

    Returns:
        The ONNX model with the external data.

    Raises:
        ValueError: If *location* does not end with ``.safetensors``, if
            *max_shard_size* cannot be parsed, or if the model's initializers
            are unnamed or have duplicate names.
    """
    location_str = str(location)
    # Ensure that external_data ends with .safetensors
    if not location_str.endswith(".safetensors"):
        raise ValueError(
            f'The path to safetensors file must have a .safetensors extension, got: "{location}"'
        )

    max_shard_size_bytes = (
        _parse_size_string(max_shard_size) if max_shard_size is not None else None
    )

    if isinstance(model, onnx.ModelProto):
        model_ir = ir.serde.deserialize_model(model)
    else:
        model_ir = model

    # First, collect metadata without loading tensor data. Names come from the
    # owning value (never None, and unique, per _get_value_tensor_pairs) rather
    # than from tensor.name, which may be unset or differ from the value name.
    names_to_save: list[str] = []
    tensors_to_save: list[ir.TensorProtocol] = []
    for value, tensor in _get_value_tensor_pairs(model_ir):
        if tensor.nbytes < size_threshold:
            continue
        names_to_save.append(value.name)  # type: ignore[arg-type]
        tensors_to_save.append(tensor)

    if tensors_to_save:
        total_size = sum(tensor.nbytes for tensor in tensors_to_save)

        # Determine sharding based on max_shard_size_bytes. When it is None
        # there is a single shard, which is the same as no sharding.
        tensor_shards = _shard_tensors(tensors_to_save, max_shard_size_bytes)
        total_shards = len(tensor_shards)

        # _shard_tensors keeps the input order, so consuming the names in
        # lockstep pairs each tensor with the name of the value that owns it.
        remaining_names = iter(names_to_save)

        # Save each shard, loading only necessary tensor data
        all_filenames: list[str] = []
        weight_map: dict[str, str] = {}  # Maps tensor name to shard filename
        for shard_idx, tensor_shard in enumerate(tensor_shards, start=1):
            shard_filename = _get_shard_filename(location_str, shard_idx, total_shards)
            shard_path = os.path.join(base_dir, shard_filename)
            all_filenames.append(shard_filename)

            # Build tensor_dict for this shard only
            shard_dict: dict[str, Any] = {}
            for tensor in (pbar := tqdm(tensor_shard)):
                name = next(remaining_names)
                pbar.set_description(f"Saving {shard_filename} ({name})")
                shard_dict[name] = {
                    "dtype": _IR_DTYPE_TO_SAFETENSORS_DTYPE[tensor.dtype],
                    "shape": _get_tensor_storage_shape(tensor),
                    "data": tensor.tobytes(),
                }
                # Update weight_map with shard filename
                weight_map[name] = shard_filename

            safetensors.serialize_file(shard_dict, shard_path)

        # Save index file if sharding occurred
        if total_shards > 1:
            # location is known to end with ".safetensors", so this yields
            # "<name>.safetensors.index.json".
            index_path = os.path.join(base_dir, f"{location_str}.index.json")
            index_data = {
                "metadata": {"total_size": total_size},
                "weight_map": weight_map,
            }
            with open(index_path, "w", encoding="utf-8") as f:
                json.dump(index_data, f, indent=2)

        # Replace tensors from each shard file
        if replace_data:
            for filename in all_filenames:
                replace_tensors(model_ir, filename, base_dir)

    if isinstance(model, onnx.ModelProto):
        return ir.serde.serialize_model(model_ir)
    return model_ir
def save_model(
    model: TModel,
    model_path: str | os.PathLike,
    /,
    *,
    external_data: str | os.PathLike | None = None,
    size_threshold: int = 0,
    max_shard_size: int | str | None = None,
) -> None:
    """Save an ONNX model to a file with external data in a safetensors file.

    The tensors held by the model's initializers before the call are put back
    afterwards, whether or not saving succeeded, so the in-memory model is
    left as it was.

    .. versionadded:: 1.3.0
        Added the function.

    Args:
        model: ONNX model to save.
        model_path: Path to the ONNX model file. E.g. "model.onnx".
        external_data: Path to the safetensors file relative to the ONNX model file.
            E.g. "model.safetensors". If not provided, it will be derived from
            the model_path by replacing the extension with ".safetensors".
        size_threshold: Minimum size in bytes for a tensor to be saved.
            Default is 0, which saves all tensors.
        max_shard_size: Maximum size in bytes for a checkpoint before being sharded.
            If expressed as a string, needs to be digits followed by a unit
            (like "5GB" or "100MB"). If None, no sharding is performed.
            When sharding is enabled, multiple safetensors files will be created
            with names like "model-00001-of-00003.safetensors", and an index file
            "model.safetensors.index.json" will be created to map tensors to
            their respective shard files.

    Raises:
        ValueError: If external_data does not end with ".safetensors".
    """
    if external_data is None:
        # Default to "<model file name without its extension>.safetensors".
        stem, _ = os.path.splitext(os.path.basename(str(model_path)))
        external_data = f"{stem}.safetensors"

    is_proto = isinstance(model, onnx.ModelProto)
    model_ir = ir.serde.deserialize_model(model) if is_proto else model

    # Snapshot the current initializer tensors so they can be put back below.
    original_pairs = _get_value_tensor_pairs(model_ir)

    try:
        save_file(
            model_ir,
            external_data,
            os.path.dirname(model_path),
            size_threshold=size_threshold,
            max_shard_size=max_shard_size,
        )
        ir.save(model_ir, model_path)
    finally:
        # Restore original initializers to avoid side effects
        for value, original_tensor in original_pairs:
            value.const_value = original_tensor
def _read_safetensors_header(file: io.IOBase) -> tuple[dict[str, dict[str, Any]], int]:
    """Read the header of a safetensors file.

    The file is rewound to its start first. The leading
    ``_HEADER_SIZE_NUMBER_SIZE`` bytes are decoded as a little-endian unsigned
    64-bit integer giving the length of the UTF-8 JSON header that follows.

    Args:
        file: The safetensors file to read, opened in binary mode.

    Returns:
        A tuple of the parsed JSON header and the header length in bytes
        (not counting the length prefix itself).
    """
    file.seek(0)
    size_prefix = file.read(_HEADER_SIZE_NUMBER_SIZE)
    (header_size,) = struct.unpack_from("<Q", size_prefix)
    raw_header = file.read(header_size)
    return json.loads(raw_header.decode("utf-8")), header_size


def _read_safetensors(
    location: str | os.PathLike, base_dir: str | os.PathLike
) -> dict[str, ir.ExternalTensor]:
    """Read a safetensors file.

    Only the header is read; each entry becomes an ``ir.ExternalTensor`` that
    records where its bytes live in the file.

    Args:
        location: The safetensors file to read.
        base_dir: Directory where the ONNX model file is stored.

    Returns:
        The contents of the safetensors file, keyed by tensor name. The
        ``__metadata__`` header entry is not included.
    """
    with open(os.path.join(base_dir, location), "rb") as file:
        header, header_size = _read_safetensors_header(file)

    result: dict[str, ir.ExternalTensor] = {}
    for name, entry in header.items():
        if name == "__metadata__":
            continue
        # data_offsets are relative to the end of the header, so shift them by
        # the header length plus the length prefix to get absolute offsets.
        absolute_offset = entry["data_offsets"][0] + header_size + _HEADER_SIZE_NUMBER_SIZE
        byte_length = entry["data_offsets"][1] - entry["data_offsets"][0]
        result[name] = ir.ExternalTensor(
            location=location,
            offset=absolute_offset,
            length=byte_length,
            dtype=_SAFETENSORS_DTYPE_TO_IR_DTYPE[entry["dtype"]],
            shape=ir.Shape(entry["shape"]),
            name=name,
            base_dir=base_dir,
        )
    return result


def _check_tensors_match(
    model_tensor: ir.TensorProtocol, safe_tensor: ir.ExternalTensor
):
    """Check if two tensors match.

    Only the byte sizes are compared.

    Args:
        model_tensor: Tensor from the model.
        safe_tensor: Tensor from the safetensors file.

    Raises:
        ValueError: If the tensors do not match.
    """
    if model_tensor.nbytes == safe_tensor.nbytes:
        return
    raise ValueError(
        f"Tensor size mismatch for tensor '{model_tensor.name}': "
        f"model tensor size {model_tensor.nbytes} bytes, "
        f"safetensors tensor size {safe_tensor.nbytes} bytes. "
        f"Model tensor: {model_tensor}, Safetensors tensor: {safe_tensor}"
    )


def _migrate_tensor_shape_dtype(
    model_tensor: ir.TensorProtocol, safe_tensor: ir.ExternalTensor
) -> ir.ExternalTensor:
    """Migrate the shape and dtype of a tensor.

    This is needed because we store 4bit and 2bit tensors as UINT8 in safetensors.

    Args:
        model_tensor: The tensor to migrate.
        safe_tensor: The tensor to migrate to.

    Returns:
        ``safe_tensor`` unchanged when the model dtype needs no migration;
        otherwise a new external tensor with the same location, offset,
        length, name and base_dir as ``safe_tensor`` but carrying the model
        tensor's dtype and shape.
    """
    needs_migration = {
        # Types that safetensors does not support directly
        ir.DataType.FLOAT8E4M3FNUZ,
        ir.DataType.FLOAT8E5M2FNUZ,
        ir.DataType.FLOAT4E2M1,  # Still need to migrate shape
        ir.DataType.INT4,
        ir.DataType.INT2,
        ir.DataType.UINT4,
        ir.DataType.UINT2,
    }
    if model_tensor.dtype not in needs_migration:
        return safe_tensor

    return ir.ExternalTensor(
        location=safe_tensor.location,
        offset=safe_tensor.offset,
        length=safe_tensor.length,
        dtype=model_tensor.dtype,
        shape=model_tensor.shape,  # type: ignore[arg-type]
        name=safe_tensor.name,
        base_dir=safe_tensor.base_dir,
    )