"""Private module for loading and saving safetensors data to ONNX models."""
from __future__ import annotations
import io
import json
import math
import os
import re
import struct
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, TypeVar
import onnx
import onnx_ir as ir
import safetensors
from tqdm.auto import tqdm
from onnx_safetensors import _tensors
# No type-checking-only imports are currently needed; the guard is kept as a
# placeholder.
if TYPE_CHECKING:
    pass
# Width, in bytes, of the integer that starts a safetensors file and holds the
# length of its JSON header. Tensor data offsets recorded in the header are
# relative to the end of this prefix plus the header.
_HEADER_SIZE_NUMBER_SIZE = 8
# Safetensors dtype strings, as they appear in a file header or in the output
# of ``safetensors.deserialize``, mapped to ONNX IR data types.
# https://github.com/huggingface/safetensors/blob/543243c3017e413584f27ebd4b99c844f62deb34/safetensors/src/tensor.rs#L664
_SAFETENSORS_DTYPE_TO_IR_DTYPE = {
    "BOOL": ir.DataType.BOOL,
    "F8_E5M2": ir.DataType.FLOAT8E5M2,
    "F8_E4M3": ir.DataType.FLOAT8E4M3FN,
    "BF16": ir.DataType.BFLOAT16,
    "F16": ir.DataType.FLOAT16,
    "F32": ir.DataType.FLOAT,
    "F64": ir.DataType.DOUBLE,
    "I8": ir.DataType.INT8,
    "I16": ir.DataType.INT16,
    "I32": ir.DataType.INT32,
    "I64": ir.DataType.INT64,
    "U8": ir.DataType.UINT8,
    "U16": ir.DataType.UINT16,
    "U32": ir.DataType.UINT32,
    "U64": ir.DataType.UINT64,
}
# ONNX IR data types mapped to the dtype names passed to
# ``safetensors.serialize`` / ``safetensors.serialize_file``.
# Types without a safetensors counterpart (the 4-bit types and the FNUZ float8
# types) are written as raw "uint8" data; when read back they are relabelled
# with the model's dtype (see ``_migrate_tensor_shape_dtype``).
_IR_DTYPE_TO_SAFETENSORS_DTYPE = {
    ir.DataType.BOOL: "bool",
    ir.DataType.FLOAT4E2M1: "uint8",
    ir.DataType.FLOAT8E5M2: "float8_e5m2",
    ir.DataType.FLOAT8E4M3FN: "float8_e4m3fn",
    ir.DataType.FLOAT8E4M3FNUZ: "uint8",
    ir.DataType.FLOAT8E5M2FNUZ: "uint8",
    ir.DataType.BFLOAT16: "bfloat16",
    ir.DataType.FLOAT16: "float16",
    ir.DataType.FLOAT: "float32",
    ir.DataType.DOUBLE: "float64",
    ir.DataType.INT4: "uint8",
    ir.DataType.INT8: "int8",
    ir.DataType.INT16: "int16",
    ir.DataType.INT32: "int32",
    ir.DataType.INT64: "int64",
    ir.DataType.UINT4: "uint8",
    ir.DataType.UINT8: "uint8",
    ir.DataType.UINT16: "uint16",
    ir.DataType.UINT32: "uint32",
    ir.DataType.UINT64: "uint64",
}
# Type variable over the two model representations accepted by the public API;
# the load/save_file helpers return the same representation they are given.
TModel = TypeVar("TModel", onnx.ModelProto, ir.Model)
def _parse_size_string(size: int | str) -> int:
"""Parse a size string like '5GB' or '100MB' into bytes.
Args:
size: Either an integer representing bytes, or a string like '5GB', '100MB', etc.
Returns:
The size in bytes.
Raises:
ValueError: If the size string format is invalid.
"""
if isinstance(size, int):
return size
size = size.strip()
match = re.match(r"(\d+(?:\.\d+)?)\s*([A-Za-z]+)", size)
if not match:
raise ValueError(
f"Invalid size format: {size}. Expected format like '5GB' or '100MB'."
)
num_str, unit = match.groups()
num = float(num_str)
# Convert to bytes
unit = unit.upper()
multipliers = {
"B": 1,
"KB": 1000,
"MB": 1000**2,
"GB": 1000**3,
"TB": 1000**4,
}
if unit not in multipliers:
raise ValueError(
f"Unknown size unit: {unit}. Valid units are: {', '.join(multipliers.keys())}"
)
return int(num * multipliers[unit])
def _get_shard_filename(base_name: str, shard_idx: int, total_shards: int) -> str:
"""Generate a filename for a shard.
Args:
base_name: The base filename (e.g., 'model.safetensors').
shard_idx: The index of this shard (1-indexed).
total_shards: The total number of shards.
Returns:
The shard filename (e.g., 'model-00001-of-00003.safetensors').
"""
if total_shards == 1:
return base_name
# Extract extension
if "." in base_name:
name, ext = base_name.rsplit(".", 1)
ext = f".{ext}"
else:
name = base_name
ext = ""
# Always use 5 digits to follow transformers convention
return f"{name}-{shard_idx:05d}-of-{total_shards:05d}{ext}"
def _shard_tensors(
    tensor_metadata: dict[str, dict[str, Any]], max_shard_size: int | str
) -> list[list[str]]:
    """Group tensor names into consecutive shards that respect a size budget.

    Tensors are assigned greedily in the iteration order of *tensor_metadata*.
    A new shard is opened whenever the next tensor would push the current
    shard past the budget. If the current shard has not accumulated any bytes
    yet, no new shard is opened, so a tensor larger than the budget simply
    gets a shard of its own.

    Args:
        tensor_metadata: Mapping of tensor name to metadata. Only the
            ``"size"`` entry (in bytes) is consulted.
        max_shard_size: Budget per shard, in bytes or as a string like '5GB'.

    Returns:
        One list of tensor names per shard, in order. The result always holds
        at least one (possibly empty) shard.
    """
    budget = _parse_size_string(max_shard_size)
    groups: list[list[str]] = []
    current: list[str] = []
    used = 0
    for name, info in tensor_metadata.items():
        size = info["size"]
        if used > 0 and used + size > budget:
            # Close the current shard and start a fresh one for this tensor.
            groups.append(current)
            current = []
            used = 0
        current.append(name)
        used += size
    groups.append(current)
    return groups
def _apply_tensors(
    model: ir.Model,
    tensors: Mapping[str, ir.TensorProtocol],
    apply_safetensors: bool = False,
):
    """Replace initializer values in *model* with the matching entries of *tensors*.

    Only names that already exist as graph initializers are updated. Any other
    entry in *tensors* is ignored.

    Args:
        model: ONNX IR model whose initializers are updated in place.
        tensors: Replacement tensors keyed by initializer name.
        apply_safetensors: When True, each replacement must be an
            ``ir.ExternalTensor`` read from a safetensors file. If the
            initializer already holds a value, both are validated against each
            other. The replacement is then relabelled with the model's
            dtype/shape where needed.
    """
    graph = model.graph
    for name, tensor in tensors.items():
        if name not in graph.initializers:
            continue
        initializer = graph.initializers[name]
        existing = initializer.const_value
        replacement = tensor
        if apply_safetensors and existing is not None:
            assert isinstance(tensor, ir.ExternalTensor)
            _check_tensors_match(existing, tensor)
            replacement = _migrate_tensor_shape_dtype(existing, tensor)
        initializer.const_value = replacement
def _is_4bit(dtype: ir.DataType) -> bool:
    """Return whether *dtype* is one of the 4-bit types (INT4, UINT4, FLOAT4E2M1)."""
    four_bit_types = {
        ir.DataType.INT4,
        ir.DataType.UINT4,
        ir.DataType.FLOAT4E2M1,
    }
    return dtype in four_bit_types
def _is_8bit_float(dtype: ir.DataType) -> bool:
    """Return whether *dtype* is one of the 8-bit floating point types."""
    float8_types = {
        ir.DataType.FLOAT8E5M2,
        ir.DataType.FLOAT8E4M3FN,
        ir.DataType.FLOAT8E5M2FNUZ,
        ir.DataType.FLOAT8E4M3FNUZ,
    }
    return dtype in float8_types
[docs]
def replace_tensors(
    model: ir.Model, /, location: str | os.PathLike, base_dir: str | os.PathLike
) -> None:
    """Point the initializers of an ONNX model at tensors in a safetensors file.

    Every initializer whose name appears in the file gets an external tensor
    referencing the file. The model is modified in place.

    .. versionadded:: 1.0
        Added the function.

    Args:
        model: ONNX model to replace tensors in.
        location: Path to the safetensors file relative to the ONNX model file.
        base_dir: Directory where the ONNX model file is stored.
    """
    _apply_tensors(
        model,
        _read_safetensors(location, base_dir=base_dir),
        apply_safetensors=True,
    )
[docs]
def load_file(model: TModel, /, tensor_file: str | os.PathLike) -> TModel:
    """Load external data into ONNX model from a safetensors file.

    Matching initializers are first pointed at *tensor_file*, and then their
    data is loaded into memory with ``ir.external_data.load_to_model``.

    Args:
        model: ONNX model.
        tensor_file: safetensors file to load into ONNX model.

    Returns:
        The updated model, of the same type (``onnx.ModelProto`` or
        ``ir.Model``) as *model*.

    .. versionchanged:: 1.0
        The return value is now the updated ONNX model instead of a set of loaded tensor names.
    """
    is_proto = isinstance(model, onnx.ModelProto)
    model_ir = ir.serde.deserialize_model(model) if is_proto else model
    replace_tensors(model_ir, tensor_file, "")
    model_ir = ir.external_data.load_to_model(model_ir)
    return ir.serde.serialize_model(model_ir) if is_proto else model_ir
[docs]
def load(model: TModel, /, data: bytes) -> TModel:
    """Load external data into ONNX model from safetensors bytes.

    Every tensor in *data* whose name matches an existing initializer replaces
    that initializer's value. Other tensors are ignored. Tensors are used
    with the dtype and shape recorded in the safetensors data.

    NOTE(review): unlike ``load_file``, this path does not reconcile dtypes
    with the model. Initializers that were saved as raw "uint8" (4-bit and
    FNUZ float8 types) presumably come back as UINT8 tensors; confirm.

    Args:
        model: ONNX model.
        data: safetensors bytes to load into ONNX model.

    Returns:
        The updated model, of the same type (``onnx.ModelProto`` or
        ``ir.Model``) as *model*.

    .. versionchanged:: 1.0
        The return value is now the updated ONNX model instead of a set of loaded tensor names.
    """
    is_proto = isinstance(model, onnx.ModelProto)
    model_ir = ir.serde.deserialize_model(model) if is_proto else model
    # TODO: Handle more dtypes
    tensors_dict = {}
    for name, info in safetensors.deserialize(data):
        tensors_dict[name] = _tensors.ByteArrayTensor(
            data=info["data"],
            dtype=_SAFETENSORS_DTYPE_TO_IR_DTYPE[info["dtype"]],
            shape=ir.Shape(info["shape"]),
            name=name,
        )
    _apply_tensors(model_ir, tensors_dict)
    return ir.serde.serialize_model(model_ir) if is_proto else model_ir
[docs]
def load_file_as_external_data(
    model: TModel, /, location: str | os.PathLike, base_dir: str | os.PathLike = ""
) -> TModel:
    """Load weights from safetensors file and use them as external data for the ONNX model.

    Matching initializers are pointed at the safetensors file. Unlike
    ``load_file``, no ``ir.external_data.load_to_model`` call is made here.

    .. versionadded:: 1.0
        Added the function.

    Args:
        model: ONNX model or graph to load external data into.
        location: Path to the safetensors file relative to the ONNX model file.
        base_dir: Directory where the ONNX model file is stored.

    Returns:
        The ONNX model with the external data, of the same type as *model*.
    """
    is_proto = isinstance(model, onnx.ModelProto)
    model_ir = ir.serde.deserialize_model(model) if is_proto else model
    replace_tensors(model_ir, location, base_dir)
    return ir.serde.serialize_model(model_ir) if is_proto else model_ir
def _get_tensor_storage_shape(tensor: ir.TensorProtocol) -> list[int]:
    """Return the shape under which *tensor*'s bytes are written to safetensors.

    4-bit tensors are stored as a flat byte buffer holding two elements per
    byte, so their storage shape is ``[ceil(num_elements / 2)]``. Every other
    tensor keeps its own shape.

    Args:
        tensor: Tensor whose shape is fully static (``shape.numpy()`` is used).

    Returns:
        The storage shape as a list of ints.
    """
    dims = tensor.shape.numpy()
    if _is_4bit(tensor.dtype):
        # Integer ceiling division; no float rounding for large element counts.
        return [(math.prod(dims) + 1) // 2]
    # Return a list to match the annotation (shape.numpy() yields a tuple).
    return list(dims)
[docs]
def save(model: TModel, /, *, size_threshold: int = 0) -> bytes:
    """Save all tensors in an ONNX model to a safetensors object serialized as bytes.

    Args:
        model: ONNX model to save.
        size_threshold: Minimum size in bytes for a tensor to be saved.
            Tensors whose ``nbytes`` is below this value are skipped.
            Default is 0, which saves all initializers.

    Returns:
        The safetensors object serialized as bytes.
    """
    if isinstance(model, onnx.ModelProto):
        model_ir = ir.serde.deserialize_model(model)
    else:
        model_ir = model
    tensor_dict: dict[str, dict[str, Any]] = {}
    for name, initializer in model_ir.graph.initializers.items():
        tensor = initializer.const_value
        # Initializers without a value have nothing to save.
        if tensor is None:
            continue
        # size_threshold is documented in bytes, so compare against nbytes;
        # tensor.size is the element count.
        if tensor.nbytes < size_threshold:
            continue
        tensor_dict[name] = {
            "dtype": _IR_DTYPE_TO_SAFETENSORS_DTYPE[tensor.dtype],
            "shape": _get_tensor_storage_shape(tensor),
            # TODO: Return a memoryview when safetensors supports it.
            "data": tensor.tobytes(),
        }
    return safetensors.serialize(tensor_dict)
[docs]
def save_file(
    model: TModel,
    /,
    location: str | os.PathLike,
    base_dir: str | os.PathLike = "",
    *,
    size_threshold: int = 0,
    replace_data: bool = True,
    max_shard_size: int | str | None = None,
) -> TModel:
    """Save all tensors in an ONNX model to a safetensors file.

    .. versionadded:: 1.0.0
        The *replace_data* parameter was added to allow the user to choose
        whether to replace the data in the ONNX model with the external data.
    .. versionremoved:: 1.0.0
        The *convert_attributes* and *strip_data* parameters were removed. Set
        *replace_data* to achieve similar effect as *strip_data*.
    .. versionchanged:: 1.0.0
        The return value is now the updated ONNX model instead of a set of saved tensor names.
    .. versionadded:: 1.0.1
        The *base_dir* parameter was added so the external data can be referenced
        relative to the ONNX model file correctly.
    .. versionadded:: 1.3.0
        The *max_shard_size* parameter was added to support sharding large models.

    Args:
        model: ONNX model proto to save.
        location: Path to the safetensors file relative to the ONNX model file.
        base_dir: Directory where the ONNX model file is stored.
        size_threshold: Minimum size in bytes for a tensor to be saved.
            Tensors whose ``nbytes`` is smaller are not written.
            Default is 0, which saves all tensors.
        replace_data: Whether to replace the data in the ONNX model with
            the external data. Default is True.
        max_shard_size: Maximum size in bytes (as int) or as a string with unit
            (like "5GB" or "100MB") for a checkpoint before being sharded.
            If None, no sharding is performed. When more than one shard is
            produced, the shards are named like
            "model-00001-of-00003.safetensors" and an index file
            "<location>.index.json" maps each tensor to its shard.

    Returns:
        The ONNX model with the external data.

    Raises:
        ValueError: If *location* does not end with ".safetensors".
    """
    # Ensure that external_data ends with .safetensors
    if not str(location).endswith(".safetensors"):
        raise ValueError(
            f"The path to safetensors file must have a .safetensors extension, got: {location}"
        )
    if isinstance(model, onnx.ModelProto):
        model_ir = ir.serde.deserialize_model(model)
    else:
        model_ir = model

    # Select the initializers to save. size_threshold is documented in bytes,
    # so it is compared against nbytes (tensor.size is the element count).
    tensors_to_save: dict[str, ir.TensorProtocol] = {}
    for name, initializer in model_ir.graph.initializers.items():
        tensor = initializer.const_value
        if tensor is None or tensor.nbytes < size_threshold:
            continue
        tensors_to_save[name] = tensor

    if max_shard_size is None:
        written_files = _save_file_single(tensors_to_save, location, base_dir)
    else:
        written_files = _save_file_sharded(
            tensors_to_save, location, base_dir, max_shard_size
        )

    if replace_data:
        # Re-point the saved initializers at the file(s) just written.
        for file_name in written_files:
            replace_tensors(model_ir, file_name, base_dir)

    if isinstance(model, onnx.ModelProto):
        return ir.serde.serialize_model(model_ir)
    return model_ir


def _save_file_single(
    tensors: dict[str, ir.TensorProtocol],
    location: str | os.PathLike,
    base_dir: str | os.PathLike,
) -> list[str | os.PathLike]:
    """Write *tensors* to one safetensors file; return the written location(s)."""
    if not tensors:
        return []
    tensor_dict = {
        name: {
            "dtype": _IR_DTYPE_TO_SAFETENSORS_DTYPE[tensor.dtype],
            "shape": _get_tensor_storage_shape(tensor),
            # TODO: Return a memoryview when safetensors supports it.
            "data": tensor.tobytes(),
        }
        for name, tensor in tensors.items()
    }
    safetensors.serialize_file(tensor_dict, os.path.join(base_dir, location))
    return [location]


def _save_file_sharded(
    tensors: dict[str, ir.TensorProtocol],
    location: str | os.PathLike,
    base_dir: str | os.PathLike,
    max_shard_size: int | str,
) -> list[str]:
    """Write *tensors* into size-bounded shards (plus an index); return shard names."""
    if not tensors:
        return []
    location_str = str(location)
    # Collect metadata first so that sharding decisions need no tensor data.
    tensor_metadata = {
        name: {
            "size": tensor.nbytes,
            "dtype": _IR_DTYPE_TO_SAFETENSORS_DTYPE[tensor.dtype],
            "shape": _get_tensor_storage_shape(tensor),
        }
        for name, tensor in tensors.items()
    }
    total_size = sum(meta["size"] for meta in tensor_metadata.values())
    shard_tensor_names = _shard_tensors(tensor_metadata, max_shard_size)
    total_shards = len(shard_tensor_names)

    shard_filenames: list[str] = []
    weight_map: dict[str, str] = {}  # Maps tensor name to shard filename
    for shard_idx, tensor_names in enumerate(shard_tensor_names, start=1):
        shard_filename = _get_shard_filename(location_str, shard_idx, total_shards)
        # Only this shard's tensor data is materialized at a time.
        shard_dict = {}
        for tensor_name in (pbar := tqdm(tensor_names)):
            pbar.set_description(f"Saving {shard_filename} ({tensor_name})")
            meta = tensor_metadata[tensor_name]
            shard_dict[tensor_name] = {
                "dtype": meta["dtype"],
                "shape": meta["shape"],
                "data": tensors[tensor_name].tobytes(),
            }
        safetensors.serialize_file(shard_dict, os.path.join(base_dir, shard_filename))
        shard_filenames.append(shard_filename)
        weight_map.update(dict.fromkeys(tensor_names, shard_filename))

    # A single shard keeps the original name and needs no index file.
    if total_shards > 1:
        # save_file guarantees location ends with ".safetensors", so this
        # yields e.g. "model.safetensors.index.json".
        index_path = os.path.join(base_dir, f"{location_str}.index.json")
        index_data = {
            "metadata": {"total_size": total_size},
            "weight_map": weight_map,
        }
        with open(index_path, "w", encoding="utf-8") as f:
            json.dump(index_data, f, indent=2)
    return shard_filenames
[docs]
def save_model(
    model: TModel,
    model_path: str | os.PathLike,
    /,
    *,
    external_data: str | os.PathLike | None = None,
    size_threshold: int = 0,
    max_shard_size: int | str | None = None,
) -> None:
    """Save an ONNX model to a file with its tensors in a safetensors file.

    The tensors are written next to the model file via ``save_file`` (with
    ``replace_data=True``). The updated model is then saved to *model_path*
    with ``ir.save``.

    .. versionadded:: 1.3.0
        Added the function.

    Args:
        model: ONNX model to save.
        model_path: Path to the ONNX model file. E.g. "model.onnx".
        external_data: Path to the safetensors file relative to the ONNX model file.
            E.g. "model.safetensors". If omitted, it is the file name of
            *model_path* with its extension replaced by ".safetensors".
        size_threshold: Minimum size in bytes for a tensor to be saved.
            Default is 0, which saves all tensors.
        max_shard_size: Maximum size in bytes for a checkpoint before being sharded.
            If expressed as a string, needs to be digits followed by a unit
            (like "5GB" or "100MB"). If None, no sharding is performed.
            When sharding is enabled, multiple safetensors files will be created
            with names like "model-00001-of-00003.safetensors", and an index
            file "model.safetensors.index.json" will be created to map tensors
            to their respective shard files.

    Raises:
        ValueError: If external_data does not end with ".safetensors".
    """
    if external_data is None:
        # "path/to/model.onnx" -> "model.safetensors"; a name without an
        # extension is used as-is.
        stem = os.path.splitext(os.path.basename(str(model_path)))[0]
        external_data = f"{stem}.safetensors"

    model_ir = (
        ir.serde.deserialize_model(model)
        if isinstance(model, onnx.ModelProto)
        else model
    )
    model_dir = os.path.dirname(model_path)
    ir.save(
        save_file(
            model_ir,
            external_data,
            model_dir,
            size_threshold=size_threshold,
            replace_data=True,
            max_shard_size=max_shard_size,
        ),
        model_path,
    )
def _read_safetensors_header(file: io.IOBase) -> tuple[dict[str, dict[str, Any]], int]:
"""Read the header of a safetensors file.
Args:
file: The safetensors file to read.
Returns:
The header of the safetensors file.
"""
file.seek(0)
header_size = struct.unpack_from("i", file.read(_HEADER_SIZE_NUMBER_SIZE))[0]
header = file.read(header_size)
return json.loads(header.decode("utf-8")), header_size
def _read_safetensors(
    location: str | os.PathLike, base_dir: str | os.PathLike
) -> dict[str, ir.ExternalTensor]:
    """Describe every tensor in a safetensors file as an ``ir.ExternalTensor``.

    Only the file header is read here. Each returned tensor records its
    location, byte offset and length within the file, together with the
    dtype and shape from the header.

    Args:
        location: The safetensors file to read, relative to *base_dir*.
        base_dir: Directory where the ONNX model file is stored.

    Returns:
        A mapping from tensor name to external tensor. The "__metadata__"
        header entry is skipped.
    """
    with open(os.path.join(base_dir, location), "rb") as file:
        header, header_size = _read_safetensors_header(file)

    result = {}
    for tensor_name, entry in header.items():
        if name_is_metadata := tensor_name == "__metadata__":
            continue
        offsets = entry["data_offsets"]
        begin = offsets[0]
        end = offsets[1]
        # Header offsets are relative to the first byte after the length
        # prefix and the JSON header.
        absolute_offset = begin + header_size + _HEADER_SIZE_NUMBER_SIZE
        result[tensor_name] = ir.ExternalTensor(
            location=location,
            offset=absolute_offset,
            length=end - begin,
            dtype=_SAFETENSORS_DTYPE_TO_IR_DTYPE[entry["dtype"]],
            shape=ir.Shape(entry["shape"]),
            name=tensor_name,
            base_dir=base_dir,
        )
        del name_is_metadata
    return result
def _check_tensors_match(
    model_tensor: ir.TensorProtocol, safe_tensor: ir.ExternalTensor
):
    """Check that a tensor read from safetensors can stand in for a model tensor.

    Rules:

    * 4-bit model tensors must be stored as UINT8 with the same number of
      bytes. Shapes are not compared, because the stored tensor holds
      packed data.
    * 8-bit float model tensors must be stored either with the exact same
      dtype or as raw UINT8 bytes. A *different* float8 dtype is rejected,
      because it encodes values differently and the model dtype would
      otherwise be silently replaced.
    * All other tensors must have exactly the same dtype.
    * Except for 4-bit tensors, shapes must be equal.

    Args:
        model_tensor: Tensor from the model.
        safe_tensor: Tensor from the safetensors file.

    Raises:
        ValueError: If the tensors do not match.
    """
    if _is_4bit(model_tensor.dtype):
        if safe_tensor.dtype != ir.DataType.UINT8:
            raise ValueError(
                f"The tensor from safetensors has dtype: {safe_tensor.dtype}, but it must be UINT8 to "
                f"represent the dtype of the tensor in the model: {model_tensor.dtype}."
            )
        if model_tensor.nbytes != safe_tensor.nbytes:
            raise ValueError(
                f"The tensor from safetensors has size: {safe_tensor.nbytes} bytes, "
                f"which does not match the size of the tensor in the model: {model_tensor.nbytes} bytes."
            )
        return
    if _is_8bit_float(model_tensor.dtype):
        if safe_tensor.dtype not in (model_tensor.dtype, ir.DataType.UINT8):
            raise ValueError(
                f"The tensor from safetensors has dtype: {safe_tensor.dtype}, but it must be "
                f"{model_tensor.dtype} or UINT8 to represent the dtype of the tensor in the "
                f"model: {model_tensor.dtype}."
            )
    elif model_tensor.dtype != safe_tensor.dtype:
        raise ValueError(
            f"The tensor from safetensors has dtype: {safe_tensor.dtype}, "
            f"which does not match the dtype of the tensor in the model: {model_tensor.dtype}."
        )
    if model_tensor.shape != safe_tensor.shape:
        raise ValueError(
            f"The tensor from safetensors has shape: {safe_tensor.shape}, "
            f"which does not match the shape of the tensor in the model: {model_tensor.shape}."
        )
def _migrate_tensor_shape_dtype(
    model_tensor: ir.TensorProtocol, safe_tensor: ir.ExternalTensor
) -> ir.ExternalTensor:
    """Relabel a safetensors tensor with the model tensor's dtype and shape when needed.

    Some ONNX dtypes are stored in safetensors as raw UINT8 bytes. The
    returned tensor points at the same bytes but carries the model's dtype
    and shape in these cases:

    * the model tensor is 4-bit (stored packed, so the stored shape differs);
    * the model tensor is a FNUZ float8 type;
    * the model tensor is any float8 type and the stored tensor is UINT8.

    Otherwise *safe_tensor* is returned unchanged.

    Args:
        model_tensor: Tensor from the model, whose dtype/shape is authoritative.
        safe_tensor: Tensor read from the safetensors file.

    Returns:
        The (possibly relabelled) external tensor.
    """
    model_dtype = model_tensor.dtype
    needs_relabel = (
        _is_4bit(model_dtype)
        or model_dtype
        in {
            ir.DataType.FLOAT8E4M3FNUZ,
            ir.DataType.FLOAT8E5M2FNUZ,
        }
        # Float8 data stored as plain UINT8 must get its float8 dtype back;
        # otherwise the initializer would silently become UINT8.
        or (_is_8bit_float(model_dtype) and safe_tensor.dtype == ir.DataType.UINT8)
    )
    if not needs_relabel:
        return safe_tensor
    return ir.ExternalTensor(
        location=safe_tensor.location,
        offset=safe_tensor.offset,
        length=safe_tensor.length,
        dtype=model_dtype,
        shape=model_tensor.shape,
        name=safe_tensor.name,
        base_dir=safe_tensor.base_dir,
    )