Files
onnx-voice-changer/lib/export_onnx.py
T

135 lines
4.5 KiB
Python

import os
import sys
import torch
import argparse
import traceback
# Add the current working directory to sys.path so that the `lib` package can be imported
sys.path.append(os.getcwd())
from lib.infer_pack.models_onnx import SynthesizerTrnMsNSFsidM
def export_model_to_onnx(model_path, output_onnx_path):
    """Export an RVC PyTorch checkpoint (.pth) to an ONNX file.

    The checkpoint is loaded on CPU, the ONNX-specific synthesizer class is
    instantiated from the checkpoint's ``config`` list, the weights are
    loaded, and the model is traced with dummy inputs via
    ``torch.onnx.export``.

    Args:
        model_path: Path to the ``.pth`` checkpoint to load.
        output_onnx_path: Path where the ONNX file is written.

    Returns:
        True if the export succeeded, False on any failure (unreadable
        checkpoint, missing/malformed metadata, model construction error or
        export error). Failures are reported on stdout; no exception is
        propagated for these cases.
    """
    print(f"Loading PyTorch checkpoint from: {model_path}")
    try:
        # Load the checkpoint onto the CPU.
        # NOTE(security): torch.load unpickles the file; only use checkpoints
        # from trusted sources.
        cpt = torch.load(model_path, map_location="cpu")
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        return False

    # Read model metadata. A malformed checkpoint (missing "config"/"weight",
    # too-short config, non-dict payload) is reported as a failure instead of
    # escaping as an uncaught exception.
    try:
        # Work on a list copy so the speaker count can be patched in even if
        # the checkpoint stored its config as a tuple.
        config = list(cpt["config"])
        weights = cpt["weight"]
        tgt_sr = config[-1]
        # Derive the number of speakers from the embedding weight, if present.
        if "emb_g.weight" in weights:
            n_spk = weights["emb_g.weight"].shape[0]
        else:
            n_spk = 1
        # Keep the config's third-from-last entry in sync with the speaker count.
        config[-3] = n_spk
        version = cpt.get("version", "v1")
        if_f0 = cpt.get("f0", 1)
    except (KeyError, IndexError, TypeError, AttributeError) as e:
        print(f"Error reading checkpoint metadata: {e!r}")
        return False

    print(f"Model Version: {version}")
    print(f"Pitch (F0) Enabled: {if_f0}")
    print(f"Target Sample Rate: {tgt_sr} Hz")
    print(f"Number of Speakers: {n_spk}")

    # Instantiate the ONNX-specific model class (SynthesizerTrnMsNSFsidM).
    # is_half is False so the export is done in FP32.
    try:
        net_g = SynthesizerTrnMsNSFsidM(*config, version=version, is_half=False)
        # Drop the enc_q submodule before loading weights (per the original
        # author's note it is not used at inference time).
        if hasattr(net_g, "enc_q"):
            del net_g.enc_q
        # strict=False so checkpoint entries for the removed enc_q are ignored.
        net_g.load_state_dict(weights, strict=False)
        net_g.eval()
        print("PyTorch model loaded successfully. Preparing dummy inputs...")
    except Exception as e:
        print(f"Failed to initialize RVC ONNX model class: {e}")
        traceback.print_exc()
        return False

    # Dummy inputs used only to trace the model during export.
    test_len = 10  # dummy sequence length
    feat_dim = 256 if version == "v1" else 768
    phone = torch.randn(1, test_len, feat_dim, dtype=torch.float32)
    phone_lengths = torch.tensor([test_len], dtype=torch.int64)
    pitch = torch.randint(1, 254, (1, test_len), dtype=torch.int64)
    nsff0 = torch.randn(1, test_len, dtype=torch.float32)
    g = torch.tensor([0], dtype=torch.int64)  # speaker ID 0
    rnd = torch.randn(1, 192, test_len, dtype=torch.float32)

    input_names = ["phone", "phone_lengths", "pitch", "nsff0", "g", "rnd"]
    output_names = ["audio"]
    # Axes marked dynamic so the exported graph accepts variable-length input.
    # NOTE(review): "audio" marks axis 1 as dynamic; confirm this matches the
    # rank/layout of the model's actual output tensor.
    dynamic_axes = {
        "phone": {1: "length"},
        "pitch": {1: "length"},
        "nsff0": {1: "length"},
        "rnd": {2: "length"},
        "audio": {1: "audio_length"},
    }

    print(f"Exporting model to ONNX format at: {output_onnx_path}")
    try:
        torch.onnx.export(
            net_g,
            (phone, phone_lengths, pitch, nsff0, g, rnd),
            output_onnx_path,
            opset_version=17,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            verbose=False,
        )
        print("ONNX model exported successfully!")
        return True
    except Exception as e:
        print(f"Error during ONNX export: {e}")
        traceback.print_exc()
        return False
if __name__ == "__main__":
    # Command-line entry point: locate the checkpoint under weights/<model_name>
    # and export it to ONNX.
    parser = argparse.ArgumentParser(description="Export RVC PyTorch .pth model to ONNX")
    parser.add_argument("--model_name", type=str, required=True, help="Nama model di folder weights (nama sub-folder)")
    parser.add_argument("--output", type=str, default="", help="Path output file ONNX (opsional)")
    args = parser.parse_args()

    model_root = "weights"
    model_dir = os.path.join(model_root, args.model_name)
    if not os.path.isdir(model_dir):
        print(f"Error: Folder '{model_dir}' tidak ditemukan!")
        sys.exit(1)

    # os.listdir returns entries in arbitrary order, so sort the candidates to
    # make the chosen checkpoint deterministic when several .pth files exist.
    # Only regular files are considered (a directory named *.pth is skipped).
    pth_files = sorted(
        f
        for f in os.listdir(model_dir)
        if f.endswith(".pth") and os.path.isfile(os.path.join(model_dir, f))
    )
    if not pth_files:
        print(f"Error: Tidak ada berkas .pth di dalam folder '{model_dir}'!")
        sys.exit(1)

    pth_name = pth_files[0]
    pth_path = os.path.join(model_dir, pth_name)

    if args.output:
        onnx_path = args.output
    else:
        # Default: save next to the checkpoint, in the same weights sub-folder.
        onnx_name = os.path.splitext(pth_name)[0] + ".onnx"
        onnx_path = os.path.join(model_dir, onnx_name)

    success = export_model_to_onnx(pth_path, onnx_path)
    if success:
        print(f"\nSelesai! Model ONNX disimpan di: {onnx_path}")
    else:
        print("\nEkspor gagal!")
        # Non-zero exit status so callers/scripts can detect the failure.
        sys.exit(1)