"""Export an RVC PyTorch ``.pth`` voice model to ONNX.

Usage: python <this script> --model_name NAME [--output PATH]

The checkpoint is looked up in ``weights/NAME/``; by default the ONNX file is
written next to it with the same base name.
"""
import argparse
import os
import sys
import traceback

import torch

# Add the current working directory to sys.path so the project-local `lib`
# package can be imported when the script is run from the repository root.
sys.path.append(os.getcwd())

from lib.infer_pack.models_onnx import SynthesizerTrnMsNSFsidM  # noqa: E402

# Length (in feature frames) of the dummy sequence used for tracing.
# Tracing freezes any shape-dependent Python control flow inside the model at
# the traced length, so this should be a realistic length rather than a
# minimal one. 200 matches RVC's own ONNX exporter.
# NOTE(review): the previous value of 10 risked baking in short-input code
# paths (e.g. relative-position padding in attention layers) — confirm
# against lib.infer_pack if this is changed.
_TRACE_LEN = 200


def _load_checkpoint(model_path):
    """Load a checkpoint onto the CPU and check it has the entries export needs.

    Returns:
        The checkpoint dict, or None (after printing the reason) if it cannot
        be read or lacks the "config"/"weight" entries.
    """
    print(f"Loading PyTorch checkpoint from: {model_path}")
    try:
        # Load to CPU so export works on machines without a GPU.
        cpt = torch.load(model_path, map_location="cpu")
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        return None

    if not isinstance(cpt, dict) or "config" not in cpt or "weight" not in cpt:
        print(
            "Error: checkpoint is missing the 'config' and/or 'weight' "
            "entries required for export."
        )
        return None
    return cpt


def _build_generator(cpt, version):
    """Instantiate SynthesizerTrnMsNSFsidM from ``cpt`` and load its weights.

    Returns:
        The model in eval mode, or None (after printing a traceback) if
        construction or weight loading fails.
    """
    try:
        # is_half=False exports in FP32 for more stable ONNX Runtime support.
        net_g = SynthesizerTrnMsNSFsidM(*cpt["config"], version=version, is_half=False)
        # The posterior encoder (enc_q) is not used at inference time; drop it.
        if hasattr(net_g, "enc_q"):
            del net_g.enc_q
        # strict=False so the checkpoint's enc_q.* tensors are tolerated; the
        # returned key lists are inspected below instead of being ignored.
        result = net_g.load_state_dict(cpt["weight"], strict=False)
        net_g.eval()
    except Exception as e:
        print(f"Failed to initialize RVC ONNX model class: {e}")
        traceback.print_exc()
        return None

    # Missing keys mean some parameters keep their initial (untrained) values
    # and the exported model would be wrong; surface them instead of hiding them.
    if result.missing_keys:
        print(
            f"Warning: {len(result.missing_keys)} model parameter(s) not found "
            f"in checkpoint: {result.missing_keys[:5]}"
        )
    unexpected = [k for k in result.unexpected_keys if not k.startswith("enc_q.")]
    if unexpected:
        print(
            f"Warning: {len(unexpected)} checkpoint tensor(s) not used by the "
            f"model: {unexpected[:5]}"
        )

    print("PyTorch model loaded successfully. Preparing dummy inputs...")
    return net_g


def _make_dummy_inputs(version, test_len=_TRACE_LEN):
    """Build the tuple of example inputs used to trace the model.

    Order matches the model call: (phone, phone_lengths, pitch, nsff0, g, rnd).
    """
    # Content-feature width: 256 for v1 models, 768 otherwise (v2).
    feat_dim = 256 if version == "v1" else 768
    phone = torch.randn(1, test_len, feat_dim, dtype=torch.float32)
    phone_lengths = torch.tensor([test_len], dtype=torch.int64)
    pitch = torch.randint(1, 254, (1, test_len), dtype=torch.int64)
    nsff0 = torch.randn(1, test_len, dtype=torch.float32)
    g = torch.tensor([0], dtype=torch.int64)  # speaker ID 0
    # 192 is assumed to be the model's latent channel count — TODO confirm
    # it matches the checkpoint config if non-standard models are exported.
    rnd = torch.randn(1, 192, test_len, dtype=torch.float32)
    return phone, phone_lengths, pitch, nsff0, g, rnd


def export_model_to_onnx(model_path, output_onnx_path):
    """Export the RVC checkpoint at ``model_path`` to ``output_onnx_path``.

    Args:
        model_path: Path to the RVC ``.pth`` checkpoint.
        output_onnx_path: Destination ``.onnx`` file; parent dirs are created.

    Returns:
        True on success, False on any failure (the reason is printed).
    """
    cpt = _load_checkpoint(model_path)
    if cpt is None:
        return False

    # Model metadata.
    tgt_sr = cpt["config"][-1]
    # The number of speakers is the row count of the speaker-embedding table.
    emb_g = cpt["weight"].get("emb_g.weight")
    n_spk = emb_g.shape[0] if emb_g is not None else 1
    # Keep the config's speaker-embedding size consistent with the weights.
    cpt["config"][-3] = n_spk

    version = cpt.get("version", "v1")
    if_f0 = cpt.get("f0", 1)

    print(f"Model Version: {version}")
    print(f"Pitch (F0) Enabled: {if_f0}")
    print(f"Target Sample Rate: {tgt_sr} Hz")
    print(f"Number of Speakers: {n_spk}")

    # The exported graph always takes pitch/nsff0 inputs; a model trained
    # without F0 has no weights for that path and would silently export
    # untrained parameters under strict=False.
    if not if_f0:
        print(
            "Error: this exporter only supports pitch (F0) models; "
            f"checkpoint reports f0={if_f0}."
        )
        return False

    net_g = _build_generator(cpt, version)
    if net_g is None:
        return False

    dummy_inputs = _make_dummy_inputs(version)
    input_names = ["phone", "phone_lengths", "pitch", "nsff0", "g", "rnd"]
    output_names = ["audio"]
    dynamic_axes = {
        "phone": {1: "length"},
        "pitch": {1: "length"},
        "nsff0": {1: "length"},
        "rnd": {2: "length"},
        # NOTE(review): assumes the model output is (batch, 1, samples), so the
        # variable-length axis is 2 (axis 1 is the single audio channel) —
        # confirm against SynthesizerTrnMsNSFsidM.forward.
        "audio": {2: "audio_length"},
    }

    print(f"Exporting model to ONNX format at: {output_onnx_path}")
    try:
        out_dir = os.path.dirname(os.path.abspath(output_onnx_path))
        os.makedirs(out_dir, exist_ok=True)
        torch.onnx.export(
            net_g,
            dummy_inputs,
            output_onnx_path,
            opset_version=17,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            verbose=False,
        )
    except Exception as e:
        print(f"Error during ONNX export: {e}")
        traceback.print_exc()
        return False

    print("ONNX model exported successfully!")
    return True


def main():
    """CLI entry point: resolve the checkpoint under ``weights/`` and export it."""
    parser = argparse.ArgumentParser(description="Export RVC PyTorch .pth model to ONNX")
    parser.add_argument("--model_name", type=str, required=True, help="Nama model di folder weights (nama sub-folder)")
    parser.add_argument("--output", type=str, default="", help="Path output file ONNX (opsional)")
    args = parser.parse_args()

    model_root = "weights"
    model_dir = os.path.join(model_root, args.model_name)

    if not os.path.isdir(model_dir):
        print(f"Error: Folder '{model_dir}' tidak ditemukan!")
        sys.exit(1)

    # os.listdir order is arbitrary; sort so the chosen checkpoint is deterministic.
    pth_files = sorted(f for f in os.listdir(model_dir) if f.endswith(".pth"))
    if not pth_files:
        print(f"Error: Tidak ada berkas .pth di dalam folder '{model_dir}'!")
        sys.exit(1)

    pth_name = pth_files[0]
    if len(pth_files) > 1:
        print(
            f"Peringatan: Ditemukan {len(pth_files)} berkas .pth di "
            f"'{model_dir}', menggunakan '{pth_name}'."
        )
    pth_path = os.path.join(model_dir, pth_name)

    if args.output:
        onnx_path = args.output
    else:
        # By default, save next to the checkpoint in the same weights sub-folder.
        onnx_name = os.path.splitext(pth_name)[0] + ".onnx"
        onnx_path = os.path.join(model_dir, onnx_name)

    if export_model_to_onnx(pth_path, onnx_path):
        print(f"\nSelesai! Model ONNX disimpan di: {onnx_path}")
    else:
        print("\nEkspor gagal!")
        sys.exit(1)


if __name__ == "__main__":
    main()