725 lines
29 KiB
Python
725 lines
29 KiB
Python
import os
|
|
import sys
|
|
import types
|
|
import json
|
|
|
|
# --- FIX FOR PATH COLLISION BETWEEN modules.py AND modules/ DIRECTORY ---
# Force Python to register 'lib.infer_pack.modules' as a package directory
# instead of a plain file module, so that sub-imports (such as F0Predictor)
# can be loaded successfully.
_MODULES_PKG_NAME = "lib.infer_pack.modules"

base_dir = os.path.dirname(os.path.abspath(__file__))
modules_path = os.path.join(base_dir, *_MODULES_PKG_NAME.split("."))

if os.path.isdir(modules_path):
    # A bare module object that carries __path__ is treated as a package by
    # the import system, which is what lets "lib.infer_pack.modules.<sub>"
    # resolve against the directory.
    modules_pkg = types.ModuleType(_MODULES_PKG_NAME)
    modules_pkg.__file__ = os.path.join(modules_path, "__init__.py")
    modules_pkg.__path__ = [modules_path]
    sys.modules[_MODULES_PKG_NAME] = modules_pkg
# ------------------------------------------------------------------------
|
|
import time
|
|
import asyncio
|
|
import logging
|
|
import traceback
|
|
import argparse
|
|
import threading
|
|
import numpy as np
|
|
import torch
|
|
import onnxruntime as ort
|
|
|
|
# Add parent directories to sys.path so we can import lib
|
|
sys.path.append(os.getcwd())
|
|
|
|
from lib.infer_pack.onnx_inference import OnnxRVC, get_f0_predictor
|
|
|
|
# Configure process-wide logging and the logger used throughout this file.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("RVC-Realtime-Server")

# Thread pool for audio processing. A single worker keeps chunk processing
# sequential while still keeping it off the asyncio event loop.
import concurrent.futures

_audio_executor = concurrent.futures.ThreadPoolExecutor(
    thread_name_prefix="rvc-audio",
    max_workers=1,
)

# Global instances cache: the most recently loaded OnnxRVC object and the
# "<model>_<device>" key it was loaded for.
current_rvc_onnx = None
current_model_key = None

# Directories (relative to the working directory) holding the exported
# models and the pretrained feature-extractor files.
model_root = "weights"
pretrained_root = "pretrained"
|
|
|
|
# Patch torch.load so that it defaults to weights_only=False for
# compatibility with checkpoints that store more than plain tensors.
original_load = torch.load


def patched_load(*args, **kwargs):
    """Call the original ``torch.load`` with ``weights_only`` defaulting to False.

    An explicit ``weights_only`` keyword supplied by the caller is respected;
    every other argument is forwarded untouched.

    NOTE(review): with ``weights_only=False`` torch.load performs full
    unpickling, so only checkpoints from trusted sources should be loaded.
    """
    kwargs.setdefault("weights_only", False)
    return original_load(*args, **kwargs)


torch.load = patched_load
|
|
|
|
def get_onnx_models(root=None):
    """Return the sorted names of model folders that contain an ONNX file.

    Args:
        root: Directory whose immediate sub-directories are scanned. Defaults
            to the module-level ``model_root`` when omitted.

    Returns:
        A sorted list of sub-directory names that hold at least one file
        ending in ``.onnx``. An empty list is returned when ``root`` does not
        exist or is not a directory.
    """
    if root is None:
        root = model_root

    # isdir (rather than exists) so a stray file at this path yields an empty
    # result instead of os.listdir raising NotADirectoryError.
    if not os.path.isdir(root):
        return []

    models = []
    for entry in os.listdir(root):
        entry_path = os.path.join(root, entry)
        if not os.path.isdir(entry_path):
            continue
        if any(name.endswith(".onnx") for name in os.listdir(entry_path)):
            models.append(entry)

    models.sort()
    return models
|
|
|
|
def get_model_metadata(model_name):
    """Locate a model's ONNX file and detect its sample rate and version.

    Args:
        model_name: Name of the model folder inside ``model_root``.

    Returns:
        A tuple ``(onnx_path, sr, vec_path, version)``:
        ``onnx_path`` is the path of the first ``.onnx`` file (alphabetical);
        ``sr`` is read from ``cpt["config"][-1]`` of the first ``.pth`` file
        when one is present and loadable, otherwise 40000;
        ``vec_path`` / ``version`` are ``"vec-256-layer-12"`` / ``"v1"`` when
        the third dimension of the ONNX model's first input is 256, otherwise
        ``"vec-768-layer-12"`` / ``"v2"``.

    Raises:
        FileNotFoundError: If the model folder contains no ``.onnx`` file
            (``os.listdir`` also raises it when the folder does not exist).
    """
    model_dir = os.path.join(model_root, model_name)

    # List once and sort: os.listdir order is arbitrary, and picking "[0]"
    # from it would make the chosen file differ between runs/platforms.
    entries = sorted(os.listdir(model_dir))
    onnx_files = [f for f in entries if f.endswith(".onnx")]
    pth_files = [f for f in entries if f.endswith(".pth")]

    if not onnx_files:
        raise FileNotFoundError(f"No .onnx model file found in: {model_dir}")

    onnx_path = os.path.join(model_dir, onnx_files[0])

    sr = 40000  # default
    if pth_files:
        try:
            pth_path = os.path.join(model_dir, pth_files[0])
            cpt = torch.load(pth_path, map_location="cpu")
            sr = cpt["config"][-1]
            logger.info(f"Detected sample rate from .pth: {sr} Hz")
        except Exception as e:
            # Best effort: a missing/unreadable checkpoint only costs us the
            # auto-detected sample rate.
            logger.warning(f"Failed to load sample rate from .pth, using default 40000: {e}")

    # Defaults used when the ONNX file cannot be inspected.
    version = "v2"
    vec_path = "vec-768-layer-12"
    try:
        sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
        feat_dim = sess.get_inputs()[0].shape[2]
        if feat_dim == 256:
            version = "v1"
            vec_path = "vec-256-layer-12"
            logger.info("Detected RVC Model Version: v1 (feat_dim = 256)")
        else:
            version = "v2"
            vec_path = "vec-768-layer-12"
            logger.info("Detected RVC Model Version: v2 (feat_dim = 768)")
    except Exception as e:
        logger.error(f"Error auto-detecting model version from ONNX: {e}")

    return onnx_path, sr, vec_path, version
|
|
|
|
class RealtimeVoiceChanger:
    """Stateful chunk-by-chunk RVC voice conversion pipeline.

    One instance holds the conversion settings, a sliding window of recent
    input audio, cached pitch predictors and (optionally) a server-side
    sounddevice stream. The loaded model itself is shared through the
    module-level ``current_rvc_onnx`` cache.
    """

    def __init__(self):
        self.processor = None
        self.model_name = ""
        self.f0_up_key = 0
        self.f0_method = "pm"
        self.device = "cuda"
        self.input_sr = 44100
        self.noise_gate_db = -40.0
        self.input_gain = 1.0
        self.output_gain = 1.0

        # Audio sliding buffers
        self.input_buffer = np.zeros(0, dtype=np.float32)
        self.history_duration = 0.30  # 300ms history = enough context even at 8192 samples @ 48kHz
        self.target_sr = 40000
        self.vec_path = "vec-768-layer-12"
        self.version = "v2"
        self.f0_predictors = {}  # Cache to reuse pitch predictors instead of recreating them on every chunk

        # Server Hardware Routing properties
        self.local_stream = None
        self.routing_mode = "browser"
        self.input_device = None
        self.output_device = None
        self.chunk_size = 8192
        self.loop = None
        self.ws_client = None
        self.visualizer_queue = None
        # Sample rate the active hardware stream was opened with (None when
        # no stream is running).
        self._stream_sr = None

        # High-pass filter coefficients cache
        self.hpf_b = None
        self.hpf_a = None
        # Filter state carried from one chunk to the next, and the sample
        # rate the cached coefficients were designed for.
        self.hpf_zi = None
        self._hpf_sr = None

    def load_model(self, model_name, device):
        """Load (or reuse the cached) RVC ONNX model for this instance.

        Args:
            model_name: Model folder name, resolved by ``get_model_metadata``.
            device: Requested device; ``"cuda"`` falls back to ``"cpu"`` when
                ``torch.cuda.is_available()`` is false.

        Raises:
            FileNotFoundError: If the ContentVec ONNX file for the detected
                model version is missing from ``pretrained_root``.
        """
        global current_rvc_onnx, current_model_key

        # Resolve the device provider
        if device == "cuda" and not torch.cuda.is_available():
            logger.warning("CUDA is not available, falling back to CPU")
            device = "cpu"

        model_key = f"{model_name}_{device}"
        if current_rvc_onnx is None or current_model_key != model_key:
            logger.info(f"Loading RVC model '{model_name}' on {device}...")
            onnx_path, sr, vec_path, version = get_model_metadata(model_name)

            # Ensure HuBERT model exists
            full_vec_path = os.path.join(pretrained_root, f"{vec_path}.onnx")
            if not os.path.exists(full_vec_path):
                raise FileNotFoundError(f"ContentVec ONNX not found at: {full_vec_path}")

            current_rvc_onnx = OnnxRVC(
                model_path=onnx_path,
                sr=sr,
                hop_size=512,
                vec_path=vec_path,
                device=device,
            )
            current_model_key = model_key
            logger.info("Model loaded successfully")

        self.processor = current_rvc_onnx
        self.target_sr = self.processor.sampling_rate
        self.model_name = model_name
        self.device = device

    def set_config(self, config):
        """Apply a (partial) configuration mapping.

        Recognised keys: ``f0_up_key``, ``f0_method``, ``input_sr``,
        ``noise_gate``, ``input_gain``, ``output_gain``, ``model_name`` and
        ``device``. Missing keys keep their current value. The model is
        (re)loaded when none is loaded yet or the model/device changed, the
        history buffer is re-created when its length no longer matches the
        input rate, and the 80 Hz high-pass filter is designed for the input
        rate.
        """
        logger.info(f"Updating config: {config}")

        # Update config fields
        self.f0_up_key = int(config.get("f0_up_key", self.f0_up_key))
        self.f0_method = config.get("f0_method", self.f0_method)
        self.input_sr = int(config.get("input_sr", self.input_sr))
        self.noise_gate_db = float(config.get("noise_gate", self.noise_gate_db))
        self.input_gain = float(config.get("input_gain", self.input_gain))
        self.output_gain = float(config.get("output_gain", self.output_gain))

        # While a hardware stream is running, the callback receives audio at
        # the rate the stream was opened with, whatever rate a later config
        # message reports.
        if self.local_stream is not None and self._stream_sr is not None:
            self.input_sr = self._stream_sr

        model_name = config.get("model_name", self.model_name)
        device = config.get("device", self.device)

        if not self.model_name or model_name != self.model_name or device != self.device:
            self.load_model(model_name, device)

        # Reset input buffer if input samplerate changed
        history_samples = int(self.history_duration * self.input_sr)
        if len(self.input_buffer) != history_samples:
            self.input_buffer = np.zeros(history_samples, dtype=np.float32)

        # Design a 1st order Butterworth high-pass filter at 80Hz to eliminate
        # low-frequency static rumbling/hums. Redesign only when needed so the
        # running filter state is not thrown away on unrelated config changes.
        if self.hpf_b is None or self.hpf_a is None or self._hpf_sr != self.input_sr:
            try:
                from scipy import signal

                nyq = 0.5 * self.input_sr
                normal_cutoff = 80.0 / nyq
                self.hpf_b, self.hpf_a = signal.butter(1, normal_cutoff, btype='high', analog=False)
                self._hpf_sr = self.input_sr
            except Exception as e:
                logger.error(f"Failed to design high-pass filter: {e}")
                self.hpf_b, self.hpf_a = None, None
                self._hpf_sr = None
            self.hpf_zi = None

    def apply_noise_gate(self, audio):
        """Return zeros when the chunk's RMS level is below the gate, else ``audio``."""
        if len(audio) == 0:
            # Nothing to measure; avoids a NaN mean on an empty array.
            return audio

        # Calculate RMS energy of the audio chunk
        rms = np.sqrt(np.mean(audio**2)) + 1e-9
        rms_db = 20 * np.log10(rms)

        if rms_db < self.noise_gate_db:
            return np.zeros_like(audio)
        return audio

    def resample(self, audio, orig_sr, target_sr):
        """Resample ``audio`` from ``orig_sr`` to ``target_sr`` by linear interpolation.

        Returns ``audio`` itself when both rates are equal, otherwise a new
        float32 array of ``int(len(audio) / orig_sr * target_sr)`` samples.
        """
        if orig_sr == target_sr:
            return audio

        if len(audio) == 0:
            # np.interp rejects an empty set of sample points.
            return np.zeros(0, dtype=np.float32)

        # Fast linear interpolation resampling
        duration = len(audio) / orig_sr
        num_target_samples = int(duration * target_sr)
        x_orig = np.linspace(0, duration, len(audio))
        x_target = np.linspace(0, duration, num_target_samples)
        return np.interp(x_target, x_orig, audio).astype(np.float32)

    def _apply_highpass(self, chunk):
        """Run the cached high-pass filter over ``chunk``, carrying state across calls.

        Filtering each chunk from a zeroed state would restart the filter at
        every chunk boundary; keeping ``hpf_zi`` makes consecutive chunks
        filter exactly like one continuous signal. Returns ``chunk`` unchanged
        when no filter is configured.
        """
        if self.hpf_b is None or self.hpf_a is None or len(chunk) == 0:
            return chunk
        try:
            from scipy import signal

            order = max(len(self.hpf_a), len(self.hpf_b)) - 1
            if self.hpf_zi is None or len(self.hpf_zi) != order:
                self.hpf_zi = np.zeros(order, dtype=np.float64)
            filtered, self.hpf_zi = signal.lfilter(self.hpf_b, self.hpf_a, chunk, zi=self.hpf_zi)
            return filtered.astype(np.float32)
        except Exception as e:
            # Best effort: report once, then disable the filter until the next
            # set_config() redesigns it, instead of failing on every chunk.
            logger.warning(f"High-pass filter failed, continuing unfiltered: {e}")
            self.hpf_b, self.hpf_a, self.hpf_zi = None, None, None
            return chunk

    @staticmethod
    def _coarse_pitch(pitchf):
        """Quantise an F0 curve in Hz to integer bins in [1, 255] on a mel scale (50-1100 Hz)."""
        f0_min = 50
        f0_max = 1100
        f0_mel_min = 1127 * np.log(1 + f0_min / 700)
        f0_mel_max = 1127 * np.log(1 + f0_max / 700)
        f0_mel = 1127 * np.log(1 + pitchf / 700)
        f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * 254 / (f0_mel_max - f0_mel_min) + 1
        f0_mel[f0_mel <= 1] = 1
        f0_mel[f0_mel > 255] = 255
        return np.rint(f0_mel).astype(np.int64)

    def process_audio_chunk(self, raw_chunk):
        """
        Process a raw input Float32 PCM audio chunk in memory with sliding window.

        Returns ``raw_chunk`` untouched when no model is loaded, a block of
        zeros when the chunk is gated to silence, and otherwise the converted
        audio for this chunk at ``target_sr``, scaled by ``output_gain`` and
        clipped to [-1, 1].
        """
        if self.processor is None:
            return raw_chunk

        t_start = time.time()

        # 1. Apply High-pass filter (80Hz Low-cut) to eliminate low-frequency background rumbles and AC hum
        chunk = self._apply_highpass(raw_chunk)

        # 2. Apply Input Gain & Noise Gate
        chunk = chunk * self.input_gain
        chunk = self.apply_noise_gate(chunk)

        # If chunk is pure silence from the gate, bypass inference immediately to save CPU!
        # (An empty chunk is treated the same way; np.max would raise on it.)
        if len(chunk) == 0 or np.max(np.abs(chunk)) < 1e-6:
            output_len = int(len(raw_chunk) * (self.target_sr / self.input_sr))
            return np.zeros(output_len, dtype=np.float32)

        t_gate = time.time()

        # 3. Manage Sliding Window Buffer
        self.input_buffer = np.append(self.input_buffer[len(chunk):], chunk)

        # Append 120ms of silence at the end to push RVC convolution edge fading into the padded future
        future_samples = int(0.12 * self.input_sr)
        full_input_audio = np.append(self.input_buffer, np.zeros(future_samples, dtype=np.float32))

        # 4. Resample full segment to 16kHz for HuBERT and RMVPE
        wav16k = self.resample(full_input_audio, self.input_sr, 16000)

        t_resample_in = time.time()

        # 4. Generate RVC ONNX inputs in-memory
        hubert = self.processor.vec_model(wav16k)
        hubert = np.repeat(hubert, 2, axis=2).transpose(0, 2, 1).astype(np.float32)
        hubert_length = hubert.shape[1]

        t_hubert = time.time()

        # Initialize and cache pitch predictor (extraction runs on the 16kHz signal)
        predictor_key = f"{self.f0_method}_{self.device}_16000"
        if predictor_key not in self.f0_predictors:
            logger.info(f"Initializing and caching 16kHz F0 Predictor '{self.f0_method}' on {self.device}...")
            hop_16k = int(self.processor.hop_size * 16000 / self.target_sr)
            self.f0_predictors[predictor_key] = get_f0_predictor(
                self.f0_method,
                hop_length=hop_16k,
                sampling_rate=16000,
                threshold=0.02,
                device=self.device,
                is_half=True if self.device == "cuda" else False,
            )
        f0_predictor = self.f0_predictors[predictor_key]

        # Calculate pitch on 16kHz audio for all methods
        pitchf = f0_predictor.compute_f0(wav16k, hubert_length)

        t_f0 = time.time()

        # Pitch transpose
        pitchf = pitchf * (2 ** (self.f0_up_key / 12))

        # Pitch binning for RVC
        pitch = self._coarse_pitch(pitchf)

        pitchf = pitchf.reshape(1, len(pitchf)).astype(np.float32)
        pitch = pitch.reshape(1, len(pitch))
        ds = np.array([0]).astype(np.int64)  # sid = 0
        rnd = np.random.randn(1, 192, hubert_length).astype(np.float32)
        hubert_length_tensor = np.array([hubert_length]).astype(np.int64)

        # 5. Run synthesis
        out_wav = self.processor.forward(hubert, hubert_length_tensor, pitch, pitchf, ds, rnd).squeeze()
        out_wav = out_wav.astype(np.float32) / 32767.0  # Normalize back to [-1.0, 1.0] float32

        t_synth = time.time()

        # 6. Extract only the newly converted chunk, discarding history and future padding.
        # The new chunk starts at (buffer_size - chunk_size) in the updated input_buffer.
        # Use exact ratio target_sr/input_sr for clean integer math, not out_wav/full_input_audio
        # which can drift due to future silence padding.
        sr_ratio = self.target_sr / self.input_sr
        history_in_buffer = len(self.input_buffer) - len(chunk)  # samples of old audio before new chunk
        start_idx = int(history_in_buffer * sr_ratio)
        end_idx = int((history_in_buffer + len(chunk)) * sr_ratio)

        # Safety clamp
        start_idx = max(0, min(start_idx, len(out_wav) - 1))
        end_idx = max(start_idx + 1, min(end_idx, len(out_wav)))

        # Extract the converted chunk from the middle of the synthesised segment
        output_chunk = out_wav[start_idx:end_idx]

        # Stretch to exactly target_chunk_len samples. resample() only uses the
        # ratio of its two "rates", so passing the current and wanted lengths
        # yields an output of the wanted length.
        target_chunk_len = int(len(chunk) * (self.target_sr / self.input_sr))
        output_chunk = self.resample(output_chunk, len(output_chunk), target_chunk_len)

        # Apply output gain
        output_chunk = output_chunk * self.output_gain

        # Ensure we don't clip
        output_chunk = np.clip(output_chunk, -1.0, 1.0)

        t_end = time.time()

        d_gate = (t_gate - t_start) * 1000
        d_res_in = (t_resample_in - t_gate) * 1000
        d_hubert = (t_hubert - t_resample_in) * 1000
        d_f0 = (t_f0 - t_hubert) * 1000
        d_synth = (t_synth - t_f0) * 1000
        d_res_out = (t_end - t_synth) * 1000
        t_elapsed = (t_end - t_start) * 1000

        logger.info(
            f"Chunk Profile: total={t_elapsed:.1f}ms | gate={d_gate:.1f}ms | res_in={d_res_in:.1f}ms | "
            f"hubert={d_hubert:.1f}ms | f0={d_f0:.1f}ms | synth={d_synth:.1f}ms | res_out={d_res_out:.1f}ms"
        )

        return output_chunk

    def start_local_stream(self, loop, ws_client):
        """Open a server-side sounddevice stream that converts audio in its callback.

        Args:
            loop: Event loop used to hand visualizer data back thread-safely.
            ws_client: Connected WebSocket client; when not None, each
                (input, output) chunk pair is queued on ``visualizer_queue``.

        Raises:
            Exception: Whatever ``sounddevice`` raises when the stream cannot
                be created or started (logged, then re-raised).
        """
        import sounddevice as sd

        # Stop any previous stream FIRST: stop_local_stream() clears
        # visualizer_queue, so it must run before the new queue is created.
        if self.local_stream is not None:
            self.stop_local_stream()

        self.loop = loop
        self.ws_client = ws_client
        self.visualizer_queue = asyncio.Queue()

        if self.input_device is None:
            self.input_device = sd.default.device[0]
        if self.output_device is None:
            self.output_device = sd.default.device[1]

        input_info = sd.query_devices(self.input_device)
        output_info = sd.query_devices(self.output_device)

        input_sr = int(input_info["default_samplerate"])
        logger.info(f"Starting Server Hardware Stream: Input='{input_info['name']}' ({input_sr}Hz) | Output='{output_info['name']}' ({self.target_sr}Hz)")

        # The duplex stream below is opened with ONE sample rate (target_sr)
        # for both directions, so the callback's input arrives at target_sr,
        # not at the device's default rate. Processing must use that rate.
        stream_sr = int(self.target_sr)
        self.set_config({
            "input_sr": stream_sr
        })

        def audio_callback(indata, outdata, frames, time_info, status):
            if status:
                logger.warning(f"Hardware Audio Callback Status: {status}")

            raw_chunk = indata[:, 0].copy()
            output_chunk = self.process_audio_chunk(raw_chunk)

            if len(output_chunk) < frames:
                outdata[:, 0] = np.pad(output_chunk, (0, frames - len(output_chunk)), "constant")
            else:
                outdata[:, 0] = output_chunk[:frames]

            # Send waveform chunks to WebSocket safely using loop.call_soon_threadsafe.
            # Read the queue once: stop_local_stream() may set it to None while
            # a late callback is still running.
            vis_queue = self.visualizer_queue
            if self.ws_client is not None and vis_queue is not None:
                loop.call_soon_threadsafe(
                    vis_queue.put_nowait,
                    (raw_chunk.copy(), output_chunk.copy())
                )

        try:
            self.local_stream = sd.Stream(
                device=(self.input_device, self.output_device),
                samplerate=self.target_sr,
                blocksize=self.chunk_size,
                channels=1,
                dtype="float32",
                callback=audio_callback
            )
            self._stream_sr = stream_sr
            self.local_stream.start()
            logger.info("Server Hardware Stream active and processing locally!")
        except Exception as e:
            logger.error(f"Failed to start hardware stream: {e}")
            raise

    def stop_local_stream(self):
        """Stop and close the hardware stream (if any) and drop the visualizer queue."""
        stream = self.local_stream
        if stream is not None:
            try:
                try:
                    stream.stop()
                finally:
                    # Release the device even when stop() fails.
                    stream.close()
                logger.info("Server Hardware Stream stopped successfully.")
            except Exception as e:
                logger.error(f"Error stopping hardware stream: {e}")
        self.local_stream = None
        self.visualizer_queue = None
        self._stream_sr = None
|
|
|
|
# --- WEBSOCKET SERVER IMPLEMENTATION ---
|
|
async def websocket_handler(websocket):
    """Serve one WebSocket client with a receiver -> processor -> sender pipeline.

    Text messages are parsed as JSON; ``{"type": "config", ...}`` updates the
    per-connection ``RealtimeVoiceChanger`` and switches between browser and
    hardware routing. Binary messages are float32 PCM chunks; each processed
    chunk is sent back as float32 bytes whose first value is the processing
    time in milliseconds. On connect the server's audio device list is sent
    as an ``init_devices`` message.
    """
    logger.info("New WebSocket client connected")
    rvc = RealtimeVoiceChanger()
    loop = asyncio.get_running_loop()

    # --- Pipeline queues ---
    # input_queue : raw bytes from browser mic
    # output_queue: processed bytes ready to send back
    # maxsize=3 = ~3 chunk durations of buffer; provides backpressure without unbounded memory
    input_queue = asyncio.Queue(maxsize=3)
    output_queue = asyncio.Queue(maxsize=3)

    def signal_shutdown(queue):
        """Post the ``None`` shutdown sentinel without ever blocking.

        An awaited put() inside a ``finally`` can hang forever when the queue
        is full and its consumer is already gone. Dropping the sentinel is
        harmless because the handler cancels every stage on exit anyway.
        """
        try:
            queue.put_nowait(None)
        except asyncio.QueueFull:
            pass

    # --- STAGE 1: Receiver ---
    # Reads ALL WebSocket messages immediately (never blocks on processing).
    # Handles JSON config inline; binary audio chunks go to input_queue.
    async def receiver_task():
        try:
            async for message in websocket:
                if isinstance(message, str):
                    try:
                        data = json.loads(message)
                        if data.get("type") == "config":
                            new_routing_mode = data.get("routing_mode", rvc.routing_mode)
                            new_input_device = data.get("input_device", rvc.input_device)
                            new_output_device = data.get("output_device", rvc.output_device)
                            new_chunk_size = int(data.get("chunk_size", rvc.chunk_size))

                            if new_input_device is not None:
                                rvc.input_device = int(new_input_device)
                            if new_output_device is not None:
                                rvc.output_device = int(new_output_device)
                            rvc.chunk_size = new_chunk_size

                            # set_config may load a model (slow) and mutates
                            # state that process_audio_chunk reads. Running it
                            # on the single-worker audio executor keeps the
                            # event loop responsive and serialises it with
                            # chunk processing.
                            await loop.run_in_executor(_audio_executor, rvc.set_config, data)

                            if new_routing_mode == "hardware":
                                if rvc.routing_mode != "hardware" or rvc.local_stream is None:
                                    rvc.routing_mode = "hardware"
                                    rvc.start_local_stream(loop, websocket)
                            else:
                                if rvc.routing_mode == "hardware":
                                    rvc.stop_local_stream()
                                rvc.routing_mode = "browser"

                            response = {
                                "type": "config_success",
                                "model_name": rvc.model_name,
                                "target_sr": rvc.target_sr,
                                "f0_method": rvc.f0_method,
                                "f0_up_key": rvc.f0_up_key,
                                "device": rvc.device,
                                "routing_mode": rvc.routing_mode
                            }
                            await websocket.send(json.dumps(response))
                    except Exception as e:
                        logger.error(f"Config parse error: {e}")
                        await websocket.send(json.dumps({"type": "error", "message": str(e)}))

                elif isinstance(message, bytes):
                    if rvc.routing_mode == "hardware":
                        continue
                    if rvc.processor is None:
                        # Prepend 0.0 processing time to echoed message
                        input_chunk = np.frombuffer(message, dtype=np.float32)
                        payload = np.empty(len(input_chunk) + 1, dtype=np.float32)
                        payload[0] = 0.0
                        payload[1:] = input_chunk
                        await websocket.send(payload.tobytes())
                        continue
                    # Put chunk in queue — await here means we yield if queue is full (backpressure)
                    await input_queue.put(message)
        except Exception as e:
            logger.error(f"Receiver error: {e}")
        finally:
            # Signal downstream stages to stop
            signal_shutdown(input_queue)

    # --- STAGE 2: Processor ---
    # Pulls from input_queue, processes in thread (non-blocking to event loop), pushes to output_queue.
    # Sequential (1 executor worker) = output order matches input order.
    async def processor_task():
        try:
            while True:
                item = await input_queue.get()
                if item is None:
                    break  # shutdown signal
                try:
                    input_chunk = np.frombuffer(item, dtype=np.float32).copy()
                    t_start = time.time()
                    output_chunk = await loop.run_in_executor(
                        _audio_executor, rvc.process_audio_chunk, input_chunk
                    )
                    t_elapsed = (time.time() - t_start) * 1000
                except asyncio.CancelledError:
                    raise
                except Exception as e:
                    # Drop only the failing chunk. Letting the error end this
                    # stage would leave nobody draining input_queue, and the
                    # receiver would then block on a full queue.
                    logger.error(f"Processor error: {e}")
                    continue

                # Prepend the elapsed processing time to the audio chunk bytes
                payload = np.empty(len(output_chunk) + 1, dtype=np.float32)
                payload[0] = t_elapsed
                payload[1:] = output_chunk
                await output_queue.put(payload.tobytes())
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.error(f"Processor error: {e}")
        finally:
            signal_shutdown(output_queue)

    # --- STAGE 3: Sender ---
    # Pulls processed audio from output_queue and sends to client.
    async def sender_task():
        try:
            while True:
                item = await output_queue.get()
                if item is None:
                    break  # shutdown signal
                await websocket.send(item)
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.error(f"Sender error: {e}")

    # --- Visualizer (hardware mode only) ---
    async def visualizer_sender_loop():
        try:
            while True:
                if rvc.routing_mode == "hardware" and rvc.visualizer_queue is not None:
                    try:
                        raw_chunk, output_chunk = await asyncio.wait_for(
                            rvc.visualizer_queue.get(), timeout=0.1
                        )
                        payload = {
                            "type": "visualizer",
                            "input": raw_chunk.tolist(),
                            "output": output_chunk.tolist()
                        }
                        await websocket.send(json.dumps(payload))
                    except asyncio.TimeoutError:
                        pass
                else:
                    await asyncio.sleep(0.05)
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.error(f"Visualizer sender error: {e}")

    # --- Send device list on connect ---
    devices_list = []
    try:
        # Imported inside the try so that a missing sounddevice/PortAudio
        # install only disables the device list instead of killing the
        # connection (browser routing does not need it).
        import sounddevice as sd

        for idx, d in enumerate(sd.query_devices()):
            devices_list.append({
                "id": idx,
                "name": d["name"],
                "max_input_channels": d["max_input_channels"],
                "max_output_channels": d["max_output_channels"],
                "default_samplerate": d["default_samplerate"]
            })
        default_input = sd.default.device[0]
        default_output = sd.default.device[1]
    except Exception as e:
        logger.error(f"Failed to query server audio devices: {e}")
        devices_list = []
        default_input = -1
        default_output = -1

    await websocket.send(json.dumps({
        "type": "init_devices",
        "devices": devices_list,
        "default_input": default_input,
        "default_output": default_output
    }))

    # --- Run all pipeline stages concurrently ---
    vis_task = asyncio.create_task(visualizer_sender_loop())
    proc_task = asyncio.create_task(processor_task())
    send_task = asyncio.create_task(sender_task())
    background_tasks = (vis_task, proc_task, send_task)

    try:
        await receiver_task()  # runs until websocket closes
    except Exception as e:
        logger.error(f"WebSocket handler error: {e}")
    finally:
        for task in background_tasks:
            task.cancel()
        # Synchronous cleanup first, so it still happens if this handler is
        # itself cancelled while waiting below.
        rvc.stop_local_stream()
        # Let the cancelled stages actually finish before returning.
        await asyncio.gather(*background_tasks, return_exceptions=True)
        logger.info("WebSocket client disconnected, pipeline cleaned up.")
|
|
|
|
|
|
async def start_websocket_server(host, port):
    """Serve ``websocket_handler`` on ``host``:``port`` until the task is cancelled."""
    import websockets

    logger.info(f"Starting WebSocket server on ws://{host}:{port}...")

    server = websockets.serve(websocket_handler, host, port)
    async with server:
        # Nothing ever completes this future, so awaiting it simply parks the
        # coroutine and keeps the server context open.
        never_done = asyncio.Future()
        await never_done
|
|
|
|
|
|
|
|
# --- LOCAL AUDIO DEVICE STREAM MODE ---
|
|
def run_local_device_mode(model_name, f0_up_key, f0_method, device, input_device, output_device, chunk_size):
    """Run voice conversion between two local audio devices until Ctrl+C.

    Args:
        model_name: Model folder name to load.
        f0_up_key: Pitch shift in semitones.
        f0_method: Pitch extraction method name.
        device: Execution device requested for the model.
        input_device: sounddevice input device ID, or None for the default.
        output_device: sounddevice output device ID, or None for the default.
        chunk_size: Stream block size in samples.
    """
    import sounddevice as sd

    logger.info("Starting Local Audio Hardware Stream Mode...")

    rvc = RealtimeVoiceChanger()
    rvc.load_model(model_name, device)

    if input_device is None:
        input_device = sd.default.device[0]
    if output_device is None:
        output_device = sd.default.device[1]

    input_info = sd.query_devices(input_device)
    output_info = sd.query_devices(output_device)

    input_sr = int(input_info["default_samplerate"])
    target_sr = rvc.target_sr

    logger.info(f"Input Device: {input_info['name']} (Sample Rate: {input_sr} Hz)")
    logger.info(f"Output Device: {output_info['name']} (Sample Rate: {target_sr} Hz)")

    # The duplex stream below is opened with ONE sample rate (target_sr) for
    # both directions, so the callback's input arrives at target_sr and not at
    # the device's default rate logged above. Telling the pipeline the device
    # default would make it resample by the wrong ratio and return fewer
    # samples than the callback has to fill.
    rvc.set_config({
        "f0_up_key": f0_up_key,
        "f0_method": f0_method,
        "input_sr": int(target_sr),
        "device": device,
        "model_name": model_name,
        "noise_gate": -40.0,
        "input_gain": 1.0,
        "output_gain": 1.0
    })

    def audio_callback(indata, outdata, frames, time_info, status):
        if status:
            logger.warning(f"Audio Callback Status: {status}")

        raw_chunk = indata[:, 0].copy()
        output_chunk = rvc.process_audio_chunk(raw_chunk)

        # Zero-pad a short result / truncate a long one to the block size.
        if len(output_chunk) < frames:
            outdata[:, 0] = np.pad(output_chunk, (0, frames - len(output_chunk)), "constant")
        else:
            outdata[:, 0] = output_chunk[:frames]

    try:
        stream = sd.Stream(
            device=(input_device, output_device),
            samplerate=target_sr,
            blocksize=chunk_size,
            channels=1,
            dtype="float32",
            callback=audio_callback
        )
        with stream:
            logger.info("Real-Time Sounddevice Stream active! Press Ctrl+C to stop.")
            while True:
                time.sleep(1)
    except KeyboardInterrupt:
        logger.info("Local stream stopped by user")
    except Exception as e:
        logger.error(f"Local stream error: {e}")
        traceback.print_exc()
|
|
|
|
# --- MAIN ---
|
|
def build_arg_parser():
    """Build the command-line parser for the server."""
    parser = argparse.ArgumentParser(description="High-Performance Real-Time RVC ONNX Server")
    parser.add_argument("--mode", type=str, default="websocket", choices=["websocket", "device"], help="Server running mode")
    parser.add_argument("--host", type=str, default="127.0.0.1", help="WebSocket host")
    parser.add_argument("--port", type=int, default=8765, help="WebSocket port")
    parser.add_argument("--model", type=str, default="", help="RVC Model folder name inside weights/")
    parser.add_argument("--transpose", type=int, default=0, help="Pitch shift in semitones (transpose)")
    parser.add_argument("--f0_method", type=str, default="pm", choices=["pm", "harvest", "dio", "rmvpe"], help="Pitch extraction method")
    parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda", "dml"], help="Execution provider")
    parser.add_argument("--input_device", type=int, default=None, help="Input device ID (for device mode)")
    parser.add_argument("--output_device", type=int, default=None, help="Output device ID (for device mode)")
    parser.add_argument("--chunk_size", type=int, default=2048, help="Audio block size in samples")
    return parser


def main(argv=None):
    """Parse the command line and start the selected server mode.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Exits with status 1 when ``--model`` is not given and no model folder is
    found by ``get_onnx_models()``.
    """
    args = build_arg_parser().parse_args(argv)

    model_name = args.model
    if not model_name:
        models = get_onnx_models()
        if models:
            model_name = models[0]
            logger.info(f"Auto-selected model: {model_name}")
        else:
            logger.error("No models found in weights/ directory. Please export a model first.")
            sys.exit(1)

    if args.mode == "websocket":
        # Start the WebSocket server on the main event loop
        try:
            asyncio.run(start_websocket_server(args.host, args.port))
        except KeyboardInterrupt:
            logger.info("Server shut down")
    elif args.mode == "device":
        run_local_device_mode(
            model_name=model_name,
            f0_up_key=args.transpose,
            f0_method=args.f0_method,
            device=args.device,
            input_device=args.input_device,
            output_device=args.output_device,
            chunk_size=args.chunk_size
        )


if __name__ == "__main__":
    main()
|