# Source code for onnx_safetensors._safetensors_io

"""Private module for loading and saving safetensors data to ONNX models."""

from __future__ import annotations

import io
import json
import os
import re
import struct
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, TypeVar

import google.protobuf.json_format
import onnx
import onnx_ir as ir
import safetensors
from tqdm.auto import tqdm

from onnx_safetensors import _tensors

if TYPE_CHECKING:
    pass


# Size in bytes of the length prefix at the very start of a safetensors file.
# `_read_safetensors_header` reads this many bytes and decodes them with the
# struct format "<Q" to obtain the size of the JSON header that follows.
_HEADER_SIZE_NUMBER_SIZE = 8
# Number of bits in a byte. Dtypes whose bitwidth is below this value are
# written as a flat buffer of `nbytes` elements (see `_get_tensor_storage_shape`).
_BYTE_SIZE = 8
# Maps the dtype strings found in a safetensors header to onnx_ir data types.
# https://github.com/huggingface/safetensors/blob/543243c3017e413584f27ebd4b99c844f62deb34/safetensors/src/tensor.rs#L664
_SAFETENSORS_DTYPE_TO_IR_DTYPE = {
    "BOOL": ir.DataType.BOOL,
    "F4": ir.DataType.FLOAT4E2M1,
    "F8_E5M2": ir.DataType.FLOAT8E5M2,
    "F8_E4M3": ir.DataType.FLOAT8E4M3FN,
    "F8_E8M0": ir.DataType.FLOAT8E8M0,
    "BF16": ir.DataType.BFLOAT16,
    "F16": ir.DataType.FLOAT16,
    "F32": ir.DataType.FLOAT,
    "F64": ir.DataType.DOUBLE,
    "I8": ir.DataType.INT8,
    "I16": ir.DataType.INT16,
    "I32": ir.DataType.INT32,
    "I64": ir.DataType.INT64,
    "U8": ir.DataType.UINT8,
    "U16": ir.DataType.UINT16,
    "U32": ir.DataType.UINT32,
    "U64": ir.DataType.UINT64,
    "C64": ir.DataType.COMPLEX64,
}
# Maps onnx_ir data types to the "dtype" string handed to
# `safetensors.serialize` / `safetensors.serialize_file`.
# The FNUZ float8 variants and the 2-/4-bit integer types are written as
# "uint8"; `_migrate_tensor_shape_dtype` restores the model's dtype and shape
# for those types when the file is applied back onto a model.
_IR_DTYPE_TO_SAFETENSORS_DTYPE = {
    ir.DataType.BOOL: "bool",
    ir.DataType.FLOAT4E2M1: "float4_e2m1fn_x2",
    ir.DataType.FLOAT8E5M2: "float8_e5m2",
    ir.DataType.FLOAT8E4M3FN: "float8_e4m3fn",
    ir.DataType.FLOAT8E8M0: "float8_e8m0",
    ir.DataType.FLOAT8E4M3FNUZ: "uint8",
    ir.DataType.FLOAT8E5M2FNUZ: "uint8",
    ir.DataType.BFLOAT16: "bfloat16",
    ir.DataType.FLOAT16: "float16",
    ir.DataType.FLOAT: "float32",
    ir.DataType.DOUBLE: "float64",
    ir.DataType.INT2: "uint8",
    ir.DataType.INT4: "uint8",
    ir.DataType.INT8: "int8",
    ir.DataType.INT16: "int16",
    ir.DataType.INT32: "int32",
    ir.DataType.INT64: "int64",
    ir.DataType.UINT2: "uint8",
    ir.DataType.UINT4: "uint8",
    ir.DataType.UINT8: "uint8",
    ir.DataType.UINT16: "uint16",
    ir.DataType.UINT32: "uint32",
    ir.DataType.UINT64: "uint64",
    ir.DataType.COMPLEX64: "complex64",
}


# Constrained type variable: the public functions annotated with it hand back
# the same kind of model (onnx.ModelProto or ir.Model) that they were given.
TModel = TypeVar("TModel", onnx.ModelProto, ir.Model)


def _parse_size_string(size: int | str) -> int:
    """Parse a size string like '5GB' or '100MB' into bytes.

    Units are decimal (1 KB == 1000 bytes) and case-insensitive. Surrounding
    whitespace and whitespace between the number and the unit are allowed.

    Args:
        size: Either an integer representing bytes, or a string like '5GB', '100MB', etc.

    Returns:
        The size in bytes. Fractional results are truncated towards zero.

    Raises:
        ValueError: If the size string format is invalid or the unit is unknown.
    """
    if isinstance(size, int):
        return size

    size = size.strip()
    # fullmatch (not match): the whole string must be "<number><unit>".
    # With a prefix match, inputs such as "5GB/2" or "5GB 3" were silently
    # accepted as "5GB" instead of being rejected.
    match = re.fullmatch(r"(\d+(?:\.\d+)?)\s*([A-Za-z]+)", size)
    if not match:
        raise ValueError(
            f"Invalid size format: {size}. Expected format like '5GB' or '100MB'."
        )

    num_str, unit = match.groups()
    num = float(num_str)

    # Convert to bytes
    unit = unit.upper()
    multipliers = {
        "B": 1,
        "KB": 1000,
        "MB": 1000**2,
        "GB": 1000**3,
        "TB": 1000**4,
    }

    if unit not in multipliers:
        raise ValueError(
            f"Unknown size unit: {unit}. Valid units are: {', '.join(multipliers.keys())}"
        )

    return int(num * multipliers[unit])


def _get_shard_filename(base_name: str, shard_idx: int, total_shards: int) -> str:
    """Generate a filename for a shard.

    Args:
        base_name: The base filename (e.g., 'model.safetensors'). It may
            contain directory components.
        shard_idx: The index of this shard (1-indexed).
        total_shards: The total number of shards.

    Returns:
        The shard filename (e.g., 'model-00001-of-00003.safetensors'). When
        there is only one shard, *base_name* is returned unchanged.
    """
    if total_shards == 1:
        return base_name

    # Look for the extension dot only inside the final path component, so a
    # dot in a directory name (e.g. 'weights.v1/model') is not mistaken for
    # the start of the extension.
    separators = {"/", os.sep, os.altsep} - {None}
    tail_start = max(base_name.rfind(sep) for sep in separators) + 1
    dot_idx = base_name.rfind(".", tail_start)
    if dot_idx == -1:
        name, ext = base_name, ""
    else:
        name, ext = base_name[:dot_idx], base_name[dot_idx:]

    # Always use 5 digits to follow transformers convention
    return f"{name}-{shard_idx:05d}-of-{total_shards:05d}{ext}"


def _shard_tensors(
    tensors: Sequence[ir.TensorProtocol], max_shard_size_bytes: int | None
) -> list[list[ir.TensorProtocol]]:
    """Shard tensors into multiple files based on max_shard_size_bytes.

    Args:
        tensors: The tensors to shard.
        max_shard_size_bytes: Maximum size for each shard in bytes. When None,
            no sharding is performed.

    Returns:
        A list of tensor lists for each shard.
    """
    if max_shard_size_bytes is None:
        # No sharding
        return [list(tensors)]

    # Shard the tensors by current order
    shards: list[list[ir.TensorProtocol]] = [[]]
    current_shard_size = 0

    for tensor in tensors:
        tensor_size = tensor.nbytes
        # Check if adding this tensor would exceed max_shard_size_bytes
        if (
            current_shard_size + tensor_size > max_shard_size_bytes
            and current_shard_size > 0
        ):
            # Start a new shard
            shards.append([])
            current_shard_size = 0

        shards[-1].append(tensor)
        current_shard_size += tensor_size

    return shards


def _apply_tensors(
    values: Sequence[ir.Value],
    tensors: Mapping[str, ir.TensorProtocol],
    apply_safetensors: bool = False,
):
    """Assign tensors to the model values that share their names.

    Tensors whose name does not match any value are ignored.

    Args:
        values: Values in the ONNX model.
        tensors: Tensors to apply to the ONNX model.
        apply_safetensors: Whether it is applying safetensors to the ONNX model.
            When set and the value already holds a tensor, the incoming tensor
            is checked against it and its dtype/shape are migrated if needed.
    """
    by_name: dict[str, ir.Value] = {v.name: v for v in values}  # type: ignore[misc]
    for tensor_name, incoming in tensors.items():
        target = by_name.get(tensor_name)
        if target is None:
            # No value in the model carries this name
            continue
        existing = target.const_value
        if apply_safetensors and existing is not None:
            assert isinstance(incoming, ir.ExternalTensor)
            _check_tensors_match(existing, incoming)
            incoming = _migrate_tensor_shape_dtype(existing, incoming)
        target.const_value = incoming


def replace_tensors(
    model: ir.Model, /, location: str | os.PathLike, base_dir: str | os.PathLike
) -> None:
    """Point the model's weights at same-named tensors in a safetensors file.

    Every initializer whose name appears in the safetensors file is replaced
    by an external-data reference into that file.

    .. versionadded:: 1.0
        Added the function.

    Args:
        model: ONNX model to replace tensors in.
        location: Path to the safetensors file relative to the ONNX model file.
        base_dir: Directory where the ONNX model file is stored.
    """
    file_tensors = _read_safetensors(location, base_dir=base_dir)
    targets = [pair[0] for pair in _get_value_tensor_pairs(model)]
    _apply_tensors(targets, file_tensors, apply_safetensors=True)
def load_file(model: TModel, /, tensor_file: str | os.PathLike) -> TModel:
    """Load tensor data from a safetensors file into an ONNX model.

    The matching weights are first referenced as external data and then
    loaded into the model.

    Args:
        model: ONNX model.
        tensor_file: safetensors file to load into ONNX model.

    Returns:
        The updated model, of the same type as *model*.

    .. versionchanged:: 1.0
        The return value is now the updated ONNX model instead of a set of loaded tensor names.
    """
    is_proto = isinstance(model, onnx.ModelProto)
    model_ir = ir.serde.deserialize_model(model) if is_proto else model

    replace_tensors(model_ir, tensor_file, "")
    model_ir = ir.external_data.load_to_model(model_ir)

    if is_proto:
        return ir.serde.serialize_model(model_ir)
    return model_ir
def load(model: TModel, /, data: bytes) -> TModel:
    """Load tensor data from serialized safetensors bytes into an ONNX model.

    Args:
        model: ONNX model.
        data: safetensors bytes to load into ONNX model.

    Returns:
        The updated model, of the same type as *model*.

    .. versionchanged:: 1.0
        The return value is now the updated ONNX model instead of a set of loaded tensor names.
    """
    is_proto = isinstance(model, onnx.ModelProto)
    model_ir = ir.serde.deserialize_model(model) if is_proto else model

    # TODO: Handle more dtypes
    tensors_dict = {}
    for name, metadata in safetensors.deserialize(data):
        tensors_dict[name] = _tensors.ByteArrayTensor(
            data=metadata["data"],
            dtype=_SAFETENSORS_DTYPE_TO_IR_DTYPE[metadata["dtype"]],
            shape=ir.Shape(metadata["shape"]),
            name=name,
        )

    targets = [pair[0] for pair in _get_value_tensor_pairs(model_ir)]
    _apply_tensors(targets, tensors_dict)

    if is_proto:
        return ir.serde.serialize_model(model_ir)
    return model_ir
def load_file_as_external_data(
    model: TModel, /, location: str | os.PathLike, base_dir: str | os.PathLike = ""
) -> TModel:
    """Reference the weights of a safetensors file as external data of the model.

    .. versionadded:: 1.0
        Added the function.

    Args:
        model: ONNX model or graph to load external data into.
        location: Path to the safetensors file relative to the ONNX model file.
        base_dir: Directory where the ONNX model file is stored.

    Returns:
        The ONNX model with the external data.
    """
    is_proto = isinstance(model, onnx.ModelProto)
    model_ir = ir.serde.deserialize_model(model) if is_proto else model

    replace_tensors(model_ir, location, base_dir)

    if is_proto:
        return ir.serde.serialize_model(model_ir)
    return model_ir
def _get_tensor_storage_shape(tensor: ir.TensorProtocol) -> Sequence[int]:
    """Return the shape under which a tensor is written to safetensors.

    Dtypes narrower than one byte are written as a flat buffer whose single
    dimension is the tensor's byte count; all other dtypes keep their shape.
    """
    narrower_than_byte = tensor.dtype.bitwidth < _BYTE_SIZE
    if not narrower_than_byte:
        return tensor.shape.numpy()
    # Handle sub-byte dtypes
    return [tensor.nbytes]
def save(model: TModel, /, *, size_threshold: int = 0) -> bytes:
    """Save all tensors in an ONNX model to a safetensors object serialized as bytes.

    Only initializers of the main graph that carry a tensor are considered.

    Args:
        model: ONNX model to save.
        size_threshold: Minimum size in bytes for a tensor to be saved.
            Default is 0, which saves all initializers.

    Returns:
        The safetensors object serialized as bytes.
    """
    if isinstance(model, onnx.ModelProto):
        model_ir = ir.serde.deserialize_model(model)
    else:
        model_ir = model

    tensor_dict: dict[str, dict[str, Any]] = {}
    for name, initializer in model_ir.graph.initializers.items():
        tensor = initializer.const_value
        if tensor is None:
            continue
        # The threshold is documented in bytes, so compare against nbytes
        # (as save_file does) rather than the element count.
        if tensor.nbytes < size_threshold:
            continue
        tensor_dict[name] = {
            "dtype": _IR_DTYPE_TO_SAFETENSORS_DTYPE[tensor.dtype],
            "shape": _get_tensor_storage_shape(tensor),
            # TODO: Return a memoryview when safetensors supports it.
            "data": tensor.tobytes(),
        }
    return safetensors.serialize(tensor_dict)
def _get_value_tensor_pairs(
    model: ir.Model,
) -> list[tuple[ir.Value, ir.TensorProtocol]]:
    """Collect every initializer that holds a tensor, across all graphs of the model.

    Args:
        model: The model whose graphs are scanned.

    Returns:
        ``(value, tensor)`` pairs in iteration order. Initializers without a
        tensor are skipped.

    Raises:
        ValueError: If an initializer has no name, or if two initializers that
            hold tensors share the same name.
    """
    pairs: list[tuple[ir.Value, ir.TensorProtocol]] = []
    seen_names: set[str] = set()
    for graph in model.graphs():
        for value in graph.initializers.values():
            # The value.name should be the same as tensor.name. However,
            # in case there is a conflict, we do not care and will prefer value.name.
            name = value.name
            if name is None:
                raise ValueError(
                    f"Initializer value '{value!r}' has no name (in graph {graph.name!r}). "
                    "All initializers must have names."
                )
            tensor = value.const_value
            if tensor is None:
                continue
            if name in seen_names:
                raise ValueError(
                    f"Duplicate initializer name found: {name} (in graph {graph.name!r})."
                    " Rename the initializers to have unique names before saving to safetensors."
                )
            seen_names.add(name)
            pairs.append((value, tensor))
    return pairs
def save_file(  # noqa: PLR0912
    model: TModel,
    /,
    location: str | os.PathLike,
    base_dir: str | os.PathLike = "",
    *,
    size_threshold: int = 0,
    replace_data: bool = True,
    max_shard_size: int | str | None = None,
) -> TModel:
    """Save all tensors in an ONNX model to a safetensors file.

    .. versionadded:: 1.0.0
        The *replace_data* parameter was added to allow the user to choose
        whether to replace the data in the ONNX model with the external data.
    .. versionremoved:: 1.0.0
        The *convert_attributes* and *strip_data* parameters were removed.
        Set *replace_data* to achieve similar effect as *strip_data*.
    .. versionchanged:: 1.0.0
        The return value is now the updated ONNX model instead of a set of saved tensor names.
    .. versionadded:: 1.0.1
        The *base_dir* parameter was added so the external data can be referenced
        relative to the ONNX model file correctly.
    .. versionadded:: 1.3.0
        The *max_shard_size* parameter was added to support sharding large models.

    Args:
        model: ONNX model proto to save.
        location: Path to the safetensors file relative to the ONNX model file.
        base_dir: Directory where the ONNX model file is stored.
        size_threshold: Minimum size in bytes for a tensor to be saved.
            Default is 0, which saves all tensors.
        replace_data: Whether to replace the data in the ONNX model with
            the external data. Default is True.
        max_shard_size: Maximum size in bytes (as int) or as a string with unit
            (like "5GB" or "100MB") for a checkpoint before being sharded.
            If None, no sharding is performed.

    Returns:
        The ONNX model with the external data.

    Raises:
        ValueError: If *location* does not end with ".safetensors", if
            *max_shard_size* cannot be parsed, or if the model's initializers
            are unnamed or have duplicate names.
    """
    # Ensure that external_data ends with .safetensors
    location_str = str(location)
    if not location_str.endswith(".safetensors"):
        raise ValueError(
            f'The path to safetensors file must have a .safetensors extension, got: "{location}"'
        )

    max_shard_size_bytes = (
        _parse_size_string(max_shard_size) if max_shard_size is not None else None
    )

    if isinstance(model, onnx.ModelProto):
        model_ir = ir.serde.deserialize_model(model)
    else:
        model_ir = model

    # Collect the tensors to write together with the name of the value that
    # owns them. The value name (not tensor.name, which may be missing or
    # different) is what replace_tensors matches on when reading the file back.
    names_to_save: list[str] = []
    tensors_to_save: list[ir.TensorProtocol] = []
    for value, tensor in _get_value_tensor_pairs(model_ir):
        if tensor.nbytes < size_threshold:
            continue
        # _get_value_tensor_pairs raises for unnamed values, so name is set
        names_to_save.append(value.name)  # type: ignore[arg-type]
        tensors_to_save.append(tensor)

    if tensors_to_save:
        total_size = sum(tensor.nbytes for tensor in tensors_to_save)

        # Determine sharding based on max_shard_size_bytes. When it is None,
        # everything goes into one shard (which is the same as no sharding).
        tensor_shards = _shard_tensors(tensors_to_save, max_shard_size_bytes)
        total_shards = len(tensor_shards)

        all_filenames = []
        weight_map: dict[str, str] = {}  # Maps tensor name to shard filename
        # _shard_tensors keeps the input order, so consecutive slices of
        # names_to_save line up with consecutive shards.
        name_offset = 0
        for shard_idx, tensor_shard in enumerate(tensor_shards, start=1):
            shard_filename = _get_shard_filename(location_str, shard_idx, total_shards)
            shard_path = os.path.join(base_dir, shard_filename)
            all_filenames.append(shard_filename)

            shard_names = names_to_save[name_offset : name_offset + len(tensor_shard)]
            name_offset += len(tensor_shard)

            # Build tensor_dict for this shard only, so tensor bytes are
            # materialized one shard at a time
            shard_dict: dict[str, Any] = {}
            for name, tensor in (pbar := tqdm(list(zip(shard_names, tensor_shard)))):
                pbar.set_description(f"Saving {shard_filename} ({name})")
                shard_dict[name] = {
                    "dtype": _IR_DTYPE_TO_SAFETENSORS_DTYPE[tensor.dtype],
                    "shape": _get_tensor_storage_shape(tensor),
                    "data": tensor.tobytes(),
                }
                weight_map[name] = shard_filename

            safetensors.serialize_file(shard_dict, shard_path)

        # Save index file if sharding occurred
        if total_shards > 1:
            # location_str is known to end with ".safetensors", so this yields
            # "<stem>.safetensors.index.json"
            index_path = os.path.join(base_dir, location_str + ".index.json")
            index_data = {
                "metadata": {"total_size": total_size},
                "weight_map": weight_map,
            }
            with open(index_path, "w", encoding="utf-8") as f:
                json.dump(index_data, f, indent=2)

        # Replace tensors from each shard file
        if replace_data:
            for filename in all_filenames:
                replace_tensors(model_ir, filename, base_dir)

    if isinstance(model, onnx.ModelProto):
        return ir.serde.serialize_model(model_ir)
    return model_ir
def save_model(
    model: TModel,
    model_path: str | os.PathLike,
    /,
    *,
    external_data: str | os.PathLike | None = None,
    size_threshold: int = 0,
    max_shard_size: int | str | None = None,
) -> None:
    """Write an ONNX model to disk with its weights stored in safetensors file(s).

    The model passed in is left with its original initializer tensors once the
    call returns, whether or not saving succeeded.

    .. versionadded:: 1.3.0
        Added the function.

    Args:
        model: ONNX model to save.
        model_path: Path to the ONNX model file. E.g. "model.onnx".
        external_data: Path to the safetensors file relative to the ONNX model file.
            E.g. "model.safetensors". If not provided, it will be derived from
            the model_path by replacing the extension with ".safetensors".
        size_threshold: Minimum size in bytes for a tensor to be saved.
            Default is 0, which saves all tensors.
        max_shard_size: Maximum size in bytes for a checkpoint before being sharded.
            If expressed as a string, needs to be digits followed by a unit
            (like "5GB" or "100MB"). If None, no sharding is performed.
            When sharding is enabled, multiple safetensors files will be created
            with names like "model-00001-of-00003.safetensors", and an index file
            "model.safetensors.index.json" will be created to map tensors to
            their respective shard files.

    Raises:
        ValueError: If external_data does not end with ".safetensors".
    """
    if external_data is None:
        # Default: the model file's base name with its extension swapped.
        # splitext leaves a name without any dot untouched.
        file_name = os.path.basename(str(model_path))
        stem = os.path.splitext(file_name)[0]
        external_data = f"{stem}.safetensors"

    if isinstance(model, onnx.ModelProto):
        model_ir = ir.serde.deserialize_model(model)
    else:
        model_ir = model

    # Remember the current initializer tensors; save_file swaps them for
    # external-data references and they are put back afterwards.
    originals = _get_value_tensor_pairs(model_ir)

    try:
        save_file(
            model_ir,
            external_data,
            os.path.dirname(model_path),
            size_threshold=size_threshold,
            max_shard_size=max_shard_size,
        )
        ir.save(model_ir, model_path)
    finally:
        # Restore original initializers to avoid side effects
        for value, original_tensor in originals:
            value.const_value = original_tensor
def save_safetensors_model(
    model: TModel,
    safetensors_model_path: str | os.PathLike,
):
    """Save an ONNX model to a safetensors file.

    This function will embed the onnx model **into** a safetensors file. This file
    is not compatible with onnx runtime, and is only useful for storage or transfer.

    The weights are written as regular safetensors entries, and the graph is
    stored as JSON under the ``"onnx"`` metadata key with every initializer
    replaced by a placeholder external-data reference. The model passed in is
    left with its original initializer tensors.

    .. versionadded:: 1.5.0
        Added the function.

    Args:
        model: ONNX model to save.
        safetensors_model_path: Path to the safetensors model file.
    """
    if isinstance(model, onnx.ModelProto):
        model_ir = ir.serde.deserialize_model(model)
    else:
        model_ir = model

    value_tensor_pairs = _get_value_tensor_pairs(model_ir)
    tensor_dict: dict[str, dict[str, Any]] = {}
    for value, tensor in value_tensor_pairs:
        name = value.name
        assert name is not None
        tensor_dict[name] = {
            "dtype": _IR_DTYPE_TO_SAFETENSORS_DTYPE[tensor.dtype],
            "shape": _get_tensor_storage_shape(tensor),
            "data": tensor.tobytes(),
        }

    try:
        for value, tensor in value_tensor_pairs:
            # There is no way to determine the offset and length before writing the
            # safetensors file. This is a chicken-and-egg problem.
            value.const_value = ir.ExternalTensor(
                location=".",
                offset=-1,
                length=tensor.nbytes,
                dtype=tensor.dtype,
                shape=tensor.shape,
                name=value.name,
                base_dir="",
            )
        # Serialize while the placeholders are in place. Doing this after the
        # originals are restored would embed the full weight data in the JSON
        # metadata, duplicating every tensor inside the file header.
        onnx_json_text = google.protobuf.json_format.MessageToJson(
            ir.serde.serialize_model(model_ir),
            preserving_proto_field_name=True,
            indent=None,
        )
    finally:
        # Restore original initializers to avoid side effects
        for value, tensor in value_tensor_pairs:
            value.const_value = tensor

    metadata = {"onnx": onnx_json_text}
    safetensors.serialize_file(tensor_dict, safetensors_model_path, metadata=metadata)
def extract_safetensors_model(
    safetensors_model_path: str | os.PathLike,
    output_path: str | os.PathLike | None = None,
) -> ir.Model:
    """Rebuild the ONNX model embedded in a safetensors file.

    This is the counterpart of `save_safetensors_model`: the model is parsed
    from the file's ``"onnx"`` metadata entry and its weights are pointed at
    the tensors stored in the same file.

    .. versionadded:: 1.5.0
        Added the function.

    Args:
        safetensors_model_path: Path to the safetensors model file.
        output_path: If provided, the extracted ONNX model (that references
            the safetensors file as external data) will be saved to this path.

    Returns:
        The extracted ONNX model.

    Raises:
        ValueError: If the file's metadata has no ``"onnx"`` entry.
    """
    metadata = _read_metadata(safetensors_model_path)
    if "onnx" not in metadata:
        raise ValueError(
            f"The safetensors file '{safetensors_model_path}' does not contain an ONNX model."
        )

    proto = google.protobuf.json_format.Parse(metadata["onnx"], onnx.ModelProto())
    model = ir.serde.deserialize_model(proto)

    # Reference the weights relative to the directory holding the file
    filename = os.path.basename(safetensors_model_path)
    base_dir = os.path.dirname(safetensors_model_path)
    replace_tensors(model, filename, base_dir=base_dir)

    if output_path is not None:
        ir.save(model, output_path)
    return model
def _read_safetensors_header(file: io.IOBase) -> tuple[dict[str, dict[str, Any]], int]:
    """Read the JSON header of a safetensors file.

    Args:
        file: The safetensors file to read, opened in binary mode.

    Returns:
        A tuple of the parsed header and the header's size in bytes (not
        counting the length prefix).
    """
    file.seek(0)
    # The file starts with the header length, decoded with struct format "<Q"
    (header_size,) = struct.unpack_from("<Q", file.read(_HEADER_SIZE_NUMBER_SIZE))
    raw_header = file.read(header_size)
    return json.loads(raw_header.decode("utf-8")), header_size


def _read_safetensors(
    location: str | os.PathLike, base_dir: str | os.PathLike
) -> dict[str, ir.ExternalTensor]:
    """Describe the tensors of a safetensors file as external tensors.

    Only the header is read; each tensor is returned as a reference (offset
    and length) into the file.

    Args:
        location: The safetensors file to read.
        base_dir: Directory where the ONNX model file is stored.

    Returns:
        The contents of the safetensors file, keyed by tensor name.
    """
    path = os.path.join(base_dir, location)
    with open(path, "rb") as file:
        header, header_size = _read_safetensors_header(file)

    # data_offsets are relative to the end of the header, so shift them by the
    # length prefix plus the header itself to get absolute file offsets.
    data_start = header_size + _HEADER_SIZE_NUMBER_SIZE
    tensors = {}
    for name, metadata in header.items():
        if name == "__metadata__":
            continue
        offsets = metadata["data_offsets"]
        begin = offsets[0]
        end = offsets[1]
        tensors[name] = ir.ExternalTensor(
            location=location,
            offset=begin + data_start,
            length=end - begin,
            dtype=_SAFETENSORS_DTYPE_TO_IR_DTYPE[metadata["dtype"]],
            shape=ir.Shape(metadata["shape"]),
            name=name,
            base_dir=base_dir,
        )
    return tensors


def _read_metadata(path: str | os.PathLike) -> dict[str, Any]:
    """Read the ``__metadata__`` entry of a safetensors file.

    Args:
        path: The safetensors file to read.

    Returns:
        The metadata of the safetensors file, or an empty dict if it has none.
    """
    with open(path, "rb") as file:
        header = _read_safetensors_header(file)[0]
    return header.get("__metadata__", {})


def _check_tensors_match(
    model_tensor: ir.TensorProtocol, safe_tensor: ir.ExternalTensor
):
    """Verify that a model tensor and a safetensors tensor have the same byte size.

    Args:
        model_tensor: Tensor from the model.
        safe_tensor: Tensor from the safetensors file.

    Raises:
        ValueError: If the tensors do not match.
    """
    if model_tensor.nbytes == safe_tensor.nbytes:
        return
    raise ValueError(
        f"Tensor size mismatch for tensor '{model_tensor.name}': "
        f"model tensor size {model_tensor.nbytes} bytes, "
        f"safetensors tensor size {safe_tensor.nbytes} bytes. "
        f"Model tensor: {model_tensor}, Safetensors tensor: {safe_tensor}"
    )


def _migrate_tensor_shape_dtype(
    model_tensor: ir.TensorProtocol, safe_tensor: ir.ExternalTensor
) -> ir.ExternalTensor:
    """Carry the model's dtype and shape over to a safetensors tensor when needed.

    This is needed because we store 4bit and 2bit tensors as UINT8 in safetensors.

    Args:
        model_tensor: The tensor to migrate.
        safe_tensor: The tensor to migrate to.

    Returns:
        The migrated tensor: a new external tensor with the model's dtype and
        shape for the affected dtypes, otherwise *safe_tensor* itself.
    """
    needs_migration = {
        # Types that safetensors does not support directly
        ir.DataType.FLOAT8E4M3FNUZ,
        ir.DataType.FLOAT8E5M2FNUZ,
        ir.DataType.FLOAT4E2M1,  # Still need to migrate shape
        ir.DataType.INT4,
        ir.DataType.INT2,
        ir.DataType.UINT4,
        ir.DataType.UINT2,
    }
    if model_tensor.dtype not in needs_migration:
        return safe_tensor
    # Same bytes in the file, reinterpreted with the model's dtype and shape
    return ir.ExternalTensor(
        location=safe_tensor.location,
        offset=safe_tensor.offset,
        length=safe_tensor.length,
        dtype=model_tensor.dtype,
        shape=model_tensor.shape,  # type: ignore[arg-type]
        name=safe_tensor.name,
        base_dir=safe_tensor.base_dir,
    )