initial commit (clean, no models)
This commit is contained in:
@@ -0,0 +1,768 @@
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
import json
|
||||
|
||||
# --- FIX FOR PATH COLLISION BETWEEN modules.py AND modules/ DIRECTORY ---
# Force Python to register 'lib.infer_pack.modules' as a package directory
# instead of a file module, so that sub-imports (such as F0Predictor) can be
# loaded successfully.
#
# NOTE: this must run before anything imports from lib.infer_pack, because the
# entry placed in sys.modules is what later imports of that name will find.
base_dir = os.path.dirname(os.path.abspath(__file__))
modules_path = os.path.join(base_dir, "lib", "infer_pack", "modules")
if os.path.isdir(modules_path):
    # Synthetic module object standing in for the package.
    modules_pkg = types.ModuleType("lib.infer_pack.modules")
    # __path__ is what makes the import system treat it as a package and
    # search this directory for submodules.
    modules_pkg.__path__ = [modules_path]
    modules_pkg.__file__ = os.path.join(modules_path, "__init__.py")
    sys.modules["lib.infer_pack.modules"] = modules_pkg
# ------------------------------------------------------------------------
|
||||
import time
|
||||
import asyncio
|
||||
import logging
|
||||
import traceback
|
||||
import argparse
|
||||
import threading
|
||||
import webbrowser
|
||||
from http.server import SimpleHTTPRequestHandler
|
||||
import socketserver
|
||||
import numpy as np
|
||||
import torch
|
||||
import onnxruntime as ort
|
||||
|
||||
# Add the current working directory to sys.path so the project-local `lib`
# package can be imported (the server is expected to be started from the
# project root).
sys.path.append(os.getcwd())

from lib.infer_pack.onnx_inference import OnnxRVC, get_f0_predictor

# Configure root logging once and create the logger used throughout this file.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("RVC-Realtime-Server")

# Thread pool for audio processing (1 worker = sequential, but non-blocking to event loop)
import concurrent.futures
_audio_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="rvc-audio")
|
||||
|
||||
# Global instances cache: the most recently loaded OnnxRVC instance and the
# "<model>_<device>" key it was loaded for (both managed by
# RealtimeVoiceChanger.load_model).
current_rvc_onnx = None
current_model_key = None
# Directories (relative to the working directory) holding the per-model
# folders and the pretrained ContentVec ONNX files.
model_root = "weights"
pretrained_root = "pretrained"
|
||||
|
||||
# Monkey-patch torch.load so that calls which do not say otherwise get
# weights_only=False (compatibility with checkpoints that need full unpickling).
# NOTE(review): weights_only=False means full pickle deserialization; only load
# checkpoint files that come from a trusted source.
original_load = torch.load


def patched_load(*args, **kwargs):
    """Forward to the original ``torch.load``, defaulting ``weights_only`` to False.

    An explicit ``weights_only`` keyword supplied by the caller is left untouched.
    """
    kwargs.setdefault("weights_only", False)
    return original_load(*args, **kwargs)


torch.load = patched_load
|
||||
|
||||
def get_onnx_models(root=None):
    """Return the sorted names of model folders that contain an ``.onnx`` file.

    Args:
        root: Directory to scan. Defaults to the module-level ``model_root``.

    Returns:
        Alphabetically sorted list of sub-directory names of ``root`` that hold
        at least one file ending in ``.onnx``. An empty list is returned when
        ``root`` does not exist or is not a directory.
    """
    root = model_root if root is None else root
    # isdir (not exists): a plain file at this path must not reach os.listdir.
    if not os.path.isdir(root):
        return []

    models = []
    for entry in os.listdir(root):
        entry_path = os.path.join(root, entry)
        if not os.path.isdir(entry_path):
            continue
        if any(name.endswith(".onnx") for name in os.listdir(entry_path)):
            models.append(entry)
    models.sort()
    return models
|
||||
|
||||
def get_model_metadata(model_name):
    """Locate a model's ONNX file and derive its sample rate and RVC version.

    Args:
        model_name: Name of the model folder inside ``model_root``.

    Returns:
        Tuple ``(onnx_path, sr, vec_path, version)``:
        ``onnx_path`` is the first ``.onnx`` file (alphabetically) in the folder,
        ``sr`` is read from ``config[-1]`` of the first ``.pth`` checkpoint when
        one is present and loadable (40000 otherwise), and ``vec_path`` /
        ``version`` are chosen from the third dimension of the ONNX model's
        first input (256 -> v1, anything else -> v2).

    Raises:
        FileNotFoundError: If the model folder contains no ``.onnx`` file
            (``os.listdir`` also raises it when the folder itself is missing).
    """
    model_dir = os.path.join(model_root, model_name)
    # Sort once so the chosen files do not depend on directory listing order.
    entries = sorted(os.listdir(model_dir))
    onnx_files = [f for f in entries if f.endswith(".onnx")]
    pth_files = [f for f in entries if f.endswith(".pth")]

    if not onnx_files:
        raise FileNotFoundError(f"No .onnx model file found in: {model_dir}")

    onnx_path = os.path.join(model_dir, onnx_files[0])
    sr = 40000  # default
    if pth_files:
        try:
            pth_path = os.path.join(model_dir, pth_files[0])
            cpt = torch.load(pth_path, map_location="cpu")
            sr = cpt["config"][-1]
            logger.info(f"Detected sample rate from .pth: {sr} Hz")
        except Exception as e:
            # Best effort: an unreadable checkpoint only costs us the auto-detected rate.
            logger.warning(f"Failed to load sample rate from .pth, using default 40000: {e}")

    # Defaults used when the ONNX model cannot be inspected.
    version = "v2"
    vec_path = "vec-768-layer-12"
    try:
        sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
        feat_dim = sess.get_inputs()[0].shape[2]
        if feat_dim == 256:
            version = "v1"
            vec_path = "vec-256-layer-12"
            logger.info("Detected RVC Model Version: v1 (feat_dim = 256)")
        else:
            # Report the dimension that was actually found instead of assuming 768.
            logger.info(f"Detected RVC Model Version: v2 (feat_dim = {feat_dim})")
    except Exception as e:
        logger.error(f"Error auto-detecting model version from ONNX: {e}")

    return onnx_path, sr, vec_path, version
|
||||
|
||||
class RealtimeVoiceChanger:
    """Chunk-by-chunk RVC voice conversion with a sliding history window.

    One instance holds the conversion settings (pitch shift, F0 method, gains,
    noise gate), the loaded ONNX processor, the input history buffer and,
    optionally, a server-side sounddevice stream ("hardware" routing).
    """

    def __init__(self):
        self.processor = None
        self.model_name = ""
        self.f0_up_key = 0
        self.f0_method = "pm"
        self.device = "cuda"
        self.input_sr = 44100
        self.noise_gate_db = -40.0
        self.input_gain = 1.0
        self.output_gain = 1.0

        # Audio sliding buffers
        self.input_buffer = np.zeros(0, dtype=np.float32)
        self.history_duration = 0.30  # 300ms history = enough context even at 8192 samples @ 48kHz
        self.target_sr = 40000
        self.vec_path = "vec-768-layer-12"
        self.version = "v2"
        self.f0_predictors = {}  # Cache to reuse pitch predictors instead of recreating them on every chunk

        # Server Hardware Routing properties
        self.local_stream = None
        self.routing_mode = "browser"
        self.input_device = None
        self.output_device = None
        self.chunk_size = 8192
        self.loop = None
        self.ws_client = None
        self.visualizer_queue = None

        # High-pass filter coefficients cache, plus the filter state carried
        # from one chunk to the next so chunk boundaries stay continuous.
        self.hpf_b = None
        self.hpf_a = None
        self.hpf_zi = None

    def load_model(self, model_name, device):
        """Load (or reuse the globally cached) ONNX model for ``model_name``.

        Args:
            model_name: Model folder name, resolved by ``get_model_metadata``.
            device: Requested execution device; "cuda" falls back to "cpu" when
                ``torch.cuda.is_available()`` is false.

        Raises:
            FileNotFoundError: If the ContentVec ONNX file for the detected
                model version is missing from ``pretrained_root``.
        """
        global current_rvc_onnx, current_model_key

        # Resolve the device provider
        if device == "cuda" and not torch.cuda.is_available():
            logger.warning("CUDA is not available, falling back to CPU")
            device = "cpu"

        model_key = f"{model_name}_{device}"
        if current_rvc_onnx is None or current_model_key != model_key:
            logger.info(f"Loading RVC model '{model_name}' on {device}...")
            onnx_path, sr, vec_path, version = get_model_metadata(model_name)

            # Ensure HuBERT model exists
            full_vec_path = os.path.join(pretrained_root, f"{vec_path}.onnx")
            if not os.path.exists(full_vec_path):
                raise FileNotFoundError(f"ContentVec ONNX not found at: {full_vec_path}")

            current_rvc_onnx = OnnxRVC(
                model_path=onnx_path,
                sr=sr,
                hop_size=512,
                vec_path=vec_path,
                device=device,
            )
            current_model_key = model_key
            logger.info("Model loaded successfully")

        self.processor = current_rvc_onnx
        self.target_sr = self.processor.sampling_rate
        self.model_name = model_name
        self.device = device

    def set_config(self, config):
        """Apply a (possibly partial) settings dict and refresh derived state.

        Recognised keys: ``f0_up_key``, ``f0_method``, ``input_sr``,
        ``noise_gate``, ``input_gain``, ``output_gain``, ``model_name`` and
        ``device``; missing keys keep their current value. The model is
        (re)loaded when no model is loaded yet or when model/device changed.
        The history buffer is re-sized for the input rate and the 80 Hz
        high-pass filter is re-designed.

        Raises:
            ValueError: If ``input_sr`` is not a positive number of Hz.
        """
        logger.info(f"Updating config: {config}")

        # Validate before mutating: a zero/negative rate would divide by zero later.
        input_sr = int(config.get("input_sr", self.input_sr))
        if input_sr <= 0:
            raise ValueError(f"input_sr must be a positive sample rate, got {input_sr}")

        # Update config fields
        self.f0_up_key = int(config.get("f0_up_key", self.f0_up_key))
        self.f0_method = config.get("f0_method", self.f0_method)
        self.input_sr = input_sr
        self.noise_gate_db = float(config.get("noise_gate", self.noise_gate_db))
        self.input_gain = float(config.get("input_gain", self.input_gain))
        self.output_gain = float(config.get("output_gain", self.output_gain))

        model_name = config.get("model_name", self.model_name)
        device = config.get("device", self.device)

        if not self.model_name or model_name != self.model_name or device != self.device:
            self.load_model(model_name, device)

        # Reset input buffer if input samplerate changed
        history_samples = int(self.history_duration * self.input_sr)
        if len(self.input_buffer) != history_samples:
            self.input_buffer = np.zeros(history_samples, dtype=np.float32)

        # Design a 1st order Butterworth high-pass filter at 80Hz to eliminate low-frequency static rumbling/hums
        try:
            from scipy import signal
            nyq = 0.5 * self.input_sr
            normal_cutoff = 80.0 / nyq
            b, a = signal.butter(1, normal_cutoff, btype='high', analog=False)
        except Exception as e:
            logger.error(f"Failed to design high-pass filter: {e}")
            self.hpf_b, self.hpf_a, self.hpf_zi = None, None, None
        else:
            # Keep the running filter state when the coefficients did not
            # change, so unrelated config updates do not cause a click.
            unchanged = (
                self.hpf_b is not None
                and self.hpf_a is not None
                and self.hpf_zi is not None
                and np.array_equal(b, self.hpf_b)
                and np.array_equal(a, self.hpf_a)
            )
            self.hpf_b, self.hpf_a = b, a
            if not unchanged:
                self.hpf_zi = np.zeros(max(len(a), len(b)) - 1, dtype=np.float64)

    def apply_noise_gate(self, audio):
        """Return ``audio`` unchanged, or zeros when its RMS level is below the gate.

        The level is ``20 * log10(rms)`` compared against ``self.noise_gate_db``.
        An empty array is returned as-is.
        """
        if len(audio) == 0:
            return audio

        # Calculate RMS energy of the audio chunk
        rms = np.sqrt(np.mean(audio**2)) + 1e-9
        rms_db = 20 * np.log10(rms)

        if rms_db < self.noise_gate_db:
            return np.zeros_like(audio)
        return audio

    def resample(self, audio, orig_sr, target_sr):
        """Resample ``audio`` from ``orig_sr`` to ``target_sr`` by linear interpolation.

        Returns the input object itself when the rates are equal or the input
        is empty; otherwise a new float32 array of
        ``int(len(audio) / orig_sr * target_sr)`` samples.
        """
        if orig_sr == target_sr or len(audio) == 0:
            return audio

        # Fast linear interpolation resampling
        duration = len(audio) / orig_sr
        num_target_samples = int(duration * target_sr)
        x_orig = np.linspace(0, duration, len(audio))
        x_target = np.linspace(0, duration, num_target_samples)
        return np.interp(x_target, x_orig, audio).astype(np.float32)

    @staticmethod
    def _fit_length(audio, num_samples):
        """Stretch or shrink ``audio`` to exactly ``num_samples`` by linear interpolation."""
        if num_samples <= 0:
            return np.zeros(0, dtype=np.float32)
        if len(audio) == 0:
            # Nothing to interpolate from: hand back silence of the right size.
            return np.zeros(num_samples, dtype=np.float32)
        if len(audio) == num_samples:
            return audio
        x_orig = np.linspace(0.0, 1.0, len(audio))
        x_target = np.linspace(0.0, 1.0, num_samples)
        return np.interp(x_target, x_orig, audio).astype(np.float32)

    @staticmethod
    def _quantize_pitch(pitchf, f0_min=50, f0_max=1100):
        """Map F0 values in Hz to integer bins 1..255 on a mel scale (1 = unvoiced/too low)."""
        f0_mel_min = 1127 * np.log(1 + f0_min / 700)
        f0_mel_max = 1127 * np.log(1 + f0_max / 700)
        f0_mel = 1127 * np.log(1 + pitchf / 700)
        voiced = f0_mel > 0
        f0_mel[voiced] = (f0_mel[voiced] - f0_mel_min) * 254 / (f0_mel_max - f0_mel_min) + 1
        f0_mel[f0_mel <= 1] = 1
        f0_mel[f0_mel > 255] = 255
        return np.rint(f0_mel).astype(np.int64)

    def _apply_high_pass(self, chunk):
        """Run the cached high-pass filter over ``chunk``, carrying state across calls."""
        try:
            from scipy import signal
            zi = self.hpf_zi
            if zi is None:
                zi = np.zeros(max(len(self.hpf_a), len(self.hpf_b)) - 1, dtype=np.float64)
            # Passing zi and storing the returned state makes consecutive chunks
            # filter exactly like one continuous signal (no boundary transient).
            filtered, self.hpf_zi = signal.lfilter(self.hpf_b, self.hpf_a, chunk, zi=zi)
            return filtered.astype(np.float32)
        except Exception as e:
            # Best effort: fall back to the unfiltered chunk rather than dropping audio.
            logger.debug(f"High-pass filter failed, passing chunk through unfiltered: {e}")
            return chunk

    def _get_f0_predictor(self):
        """Return the cached 16 kHz F0 predictor for the current method/device/hop."""
        hop_16k = int(self.processor.hop_size * 16000 / self.target_sr)
        # The hop length depends on the model sample rate, so it is part of the
        # key: switching to a model with another rate must not reuse a stale hop.
        predictor_key = f"{self.f0_method}_{self.device}_16000_{hop_16k}"
        predictor = self.f0_predictors.get(predictor_key)
        if predictor is None:
            logger.info(f"Initializing and caching 16kHz F0 Predictor '{self.f0_method}' on {self.device}...")
            predictor = get_f0_predictor(
                self.f0_method,
                hop_length=hop_16k,
                sampling_rate=16000,
                threshold=0.02,
                device=self.device,
                is_half=self.device == "cuda",
            )
            self.f0_predictors[predictor_key] = predictor
        return predictor

    def process_audio_chunk(self, raw_chunk):
        """
        Process a raw input Float32 PCM audio chunk in memory with sliding window.

        Args:
            raw_chunk: 1-D array of samples at ``self.input_sr``.

        Returns:
            ``raw_chunk`` itself when no model is loaded; otherwise a float32
            array of ``int(len(raw_chunk) * target_sr / input_sr)`` samples at
            ``self.target_sr`` (all zeros when the noise gate closed).
        """
        if self.processor is None:
            return raw_chunk
        if len(raw_chunk) == 0:
            return np.zeros(0, dtype=np.float32)

        t_start = time.time()

        # 1. Apply High-pass filter (80Hz Low-cut) to eliminate low-frequency background rumbles and AC hum
        chunk = raw_chunk
        if self.hpf_b is not None and self.hpf_a is not None:
            chunk = self._apply_high_pass(chunk)

        # 2. Apply Input Gain & Noise Gate
        chunk = chunk * self.input_gain
        chunk = self.apply_noise_gate(chunk)

        # 3. Manage Sliding Window Buffer. Done before the silence shortcut so
        # the history always mirrors the real timeline, including gated gaps.
        self.input_buffer = np.append(self.input_buffer[len(chunk):], chunk)

        # If chunk is pure silence from the gate, bypass inference immediately to save CPU!
        if np.max(np.abs(chunk)) < 1e-6:
            output_len = int(len(raw_chunk) * (self.target_sr / self.input_sr))
            return np.zeros(output_len, dtype=np.float32)

        t_gate = time.time()

        # Append 120ms of silence at the end to push RVC convolution edge fading into the padded future (no edge distortion!)
        future_samples = int(0.12 * self.input_sr)
        full_input_audio = np.append(self.input_buffer, np.zeros(future_samples, dtype=np.float32))

        # 4. Resample full segment to 16kHz for HuBERT and RMVPE
        wav16k = self.resample(full_input_audio, self.input_sr, 16000)

        t_resample_in = time.time()

        # 5. Generate RVC ONNX inputs in-memory
        hubert = self.processor.vec_model(wav16k)
        hubert = np.repeat(hubert, 2, axis=2).transpose(0, 2, 1).astype(np.float32)
        hubert_length = hubert.shape[1]

        t_hubert = time.time()

        # Pitch is extracted on the 16kHz signal with a cached predictor.
        f0_predictor = self._get_f0_predictor()
        pitchf = f0_predictor.compute_f0(wav16k, hubert_length)

        t_f0 = time.time()

        # Pitch transpose
        pitchf = pitchf * (2 ** (self.f0_up_key / 12))

        # Pitch binning for RVC
        pitch = self._quantize_pitch(pitchf)

        pitchf = pitchf.reshape(1, len(pitchf)).astype(np.float32)
        pitch = pitch.reshape(1, len(pitch))
        ds = np.array([0]).astype(np.int64)  # sid = 0
        rnd = np.random.randn(1, 192, hubert_length).astype(np.float32)
        hubert_length_tensor = np.array([hubert_length]).astype(np.int64)

        # 6. Run synthesis
        out_wav = self.processor.forward(hubert, hubert_length_tensor, pitch, pitchf, ds, rnd).squeeze()
        # atleast_1d: squeeze() of a single-sample result would be 0-d and break len().
        out_wav = np.atleast_1d(out_wav).astype(np.float32) / 32767.0  # Normalize back to [-1.0, 1.0] float32

        t_synth = time.time()

        # 7. Extract only the newly converted chunk, discarding history and future padding.
        # The new chunk starts at (buffer_size - chunk_size) in the updated input_buffer.
        # Use exact ratio target_sr/input_sr for clean integer math, not out_wav/full_input_audio
        # which can drift due to future silence padding.
        sr_ratio = self.target_sr / self.input_sr
        history_in_buffer = len(self.input_buffer) - len(chunk)  # samples of old audio before new chunk
        start_idx = int(history_in_buffer * sr_ratio)
        end_idx = int((history_in_buffer + len(chunk)) * sr_ratio)

        # Safety clamp
        start_idx = max(0, min(start_idx, len(out_wav) - 1))
        end_idx = max(start_idx + 1, min(end_idx, len(out_wav)))

        output_chunk = out_wav[start_idx:end_idx]

        # Fit to the exact expected output length so playback stays in step;
        # an empty synthesis result becomes silence instead of a division by zero.
        target_chunk_len = int(len(chunk) * sr_ratio)
        output_chunk = self._fit_length(output_chunk, target_chunk_len)

        # Apply output gain
        output_chunk = output_chunk * self.output_gain

        # Ensure we don't clip
        output_chunk = np.clip(output_chunk, -1.0, 1.0)

        t_end = time.time()

        d_gate = (t_gate - t_start) * 1000
        d_res_in = (t_resample_in - t_gate) * 1000
        d_hubert = (t_hubert - t_resample_in) * 1000
        d_f0 = (t_f0 - t_hubert) * 1000
        d_synth = (t_synth - t_f0) * 1000
        d_res_out = (t_end - t_synth) * 1000
        t_elapsed = (t_end - t_start) * 1000

        logger.info(
            f"Chunk Profile: total={t_elapsed:.1f}ms | gate={d_gate:.1f}ms | res_in={d_res_in:.1f}ms | "
            f"hubert={d_hubert:.1f}ms | f0={d_f0:.1f}ms | synth={d_synth:.1f}ms | res_out={d_res_out:.1f}ms"
        )

        return output_chunk

    def start_local_stream(self, loop, ws_client):
        """Open a duplex sounddevice stream that converts audio on the server.

        Args:
            loop: Event loop that owns the visualizer queue; the audio callback
                hands frames to it via ``call_soon_threadsafe``.
            ws_client: Client connection; visualizer frames are only queued
                while this is not None.

        Raises:
            Exception: Whatever sounddevice raises when the stream cannot be
                opened or started (logged, then re-raised).
        """
        import sounddevice as sd

        # Stop any previous stream first: stop_local_stream() clears
        # visualizer_queue, so the new queue must be created afterwards.
        if self.local_stream is not None:
            self.stop_local_stream()

        self.loop = loop
        self.ws_client = ws_client
        # Bounded so a slow or vanished consumer cannot grow memory without limit.
        visualizer_queue = asyncio.Queue(maxsize=8)
        self.visualizer_queue = visualizer_queue

        if self.input_device is None:
            self.input_device = sd.default.device[0]
        if self.output_device is None:
            self.output_device = sd.default.device[1]

        input_info = sd.query_devices(self.input_device)
        output_info = sd.query_devices(self.output_device)

        # The duplex stream below runs input AND output at target_sr, so the
        # captured chunks are at target_sr, not at the device's default rate.
        input_sr = int(self.target_sr)
        logger.info(f"Starting Server Hardware Stream: Input='{input_info['name']}' ({input_sr}Hz) | Output='{output_info['name']}' ({self.target_sr}Hz)")

        self.set_config({
            "input_sr": input_sr
        })

        def enqueue_visualizer(item):
            # Runs on the event loop; drop the frame when the consumer lags.
            try:
                visualizer_queue.put_nowait(item)
            except asyncio.QueueFull:
                pass

        def audio_callback(indata, outdata, frames, time_info, status):
            if status:
                logger.warning(f"Hardware Audio Callback Status: {status}")

            raw_chunk = indata[:, 0].copy()
            output_chunk = self.process_audio_chunk(raw_chunk)

            # outdata must always be filled completely.
            if len(output_chunk) < frames:
                outdata[:, 0] = np.pad(output_chunk, (0, frames - len(output_chunk)), "constant")
            else:
                outdata[:, 0] = output_chunk[:frames]

            # Send waveform chunks to WebSocket safely using loop.call_soon_threadsafe
            if self.ws_client is not None:
                loop.call_soon_threadsafe(
                    enqueue_visualizer,
                    (raw_chunk.copy(), output_chunk.copy())
                )

        stream = None
        try:
            stream = sd.Stream(
                device=(self.input_device, self.output_device),
                samplerate=self.target_sr,
                blocksize=self.chunk_size,
                channels=1,
                dtype="float32",
                callback=audio_callback
            )
            stream.start()
        except Exception as e:
            logger.error(f"Failed to start hardware stream: {e}")
            # Do not keep a half-open stream or an orphaned queue around.
            if stream is not None:
                try:
                    stream.close()
                except Exception as close_error:
                    logger.error(f"Error stopping hardware stream: {close_error}")
            self.visualizer_queue = None
            raise

        self.local_stream = stream
        logger.info("Server Hardware Stream active and processing locally!")

    def stop_local_stream(self):
        """Stop and close the server-side stream (if any) and drop the visualizer queue."""
        if self.local_stream is not None:
            try:
                self.local_stream.stop()
                self.local_stream.close()
                logger.info("Server Hardware Stream stopped successfully.")
            except Exception as e:
                logger.error(f"Error stopping hardware stream: {e}")
            self.local_stream = None
        self.visualizer_queue = None
|
||||
|
||||
# --- WEBSOCKET SERVER IMPLEMENTATION ---
|
||||
async def websocket_handler(websocket):
    """Serve one WebSocket client with a receive -> process -> send pipeline.

    Text frames carry JSON; ``{"type": "config", ...}`` updates the per-client
    ``RealtimeVoiceChanger`` and the routing mode. Binary frames are float32
    PCM chunks; each reply is a float32 array whose first element is the
    processing time in milliseconds followed by the converted samples.
    In "hardware" routing mode binary frames are ignored and visualizer frames
    are pushed to the client instead.

    Args:
        websocket: The client connection object handed over by ``websockets.serve``.
    """
    logger.info("New WebSocket client connected")
    rvc = RealtimeVoiceChanger()
    loop = asyncio.get_running_loop()

    # --- Pipeline queues ---
    # input_queue : raw bytes from browser mic
    # output_queue: processed bytes ready to send back
    # maxsize=3 = ~3 chunk durations of buffer; provides backpressure without unbounded memory
    input_queue = asyncio.Queue(maxsize=3)
    output_queue = asyncio.Queue(maxsize=3)

    def signal_end(queue):
        # Never block during shutdown: if the queue is full the consumer is
        # cancelled by the handler's cleanup anyway.
        try:
            queue.put_nowait(None)
        except asyncio.QueueFull:
            pass

    def frame_with_timing(elapsed_ms, samples):
        # Wire format: [elapsed_ms, sample0, sample1, ...] as float32.
        payload = np.empty(len(samples) + 1, dtype=np.float32)
        payload[0] = elapsed_ms
        payload[1:] = samples
        return payload.tobytes()

    def apply_config(data):
        """Apply one config message and return the ``config_success`` reply."""
        new_routing_mode = data.get("routing_mode", rvc.routing_mode)
        new_input_device = data.get("input_device", rvc.input_device)
        new_output_device = data.get("output_device", rvc.output_device)
        new_chunk_size = int(data.get("chunk_size", rvc.chunk_size))

        # Everything the running hardware stream was opened with.
        stream_settings_before = (rvc.input_device, rvc.output_device, rvc.chunk_size, rvc.target_sr)
        stream_active = rvc.routing_mode == "hardware" and rvc.local_stream is not None

        if new_input_device is not None:
            rvc.input_device = int(new_input_device)
        if new_output_device is not None:
            rvc.output_device = int(new_output_device)
        rvc.chunk_size = new_chunk_size

        config = data
        if stream_active and new_routing_mode == "hardware":
            # While the server-side stream is running its own rate is the input
            # rate; the browser's input_sr must not override it.
            config = {key: value for key, value in data.items() if key != "input_sr"}
        rvc.set_config(config)

        if new_routing_mode == "hardware":
            stream_settings_after = (rvc.input_device, rvc.output_device, rvc.chunk_size, rvc.target_sr)
            needs_start = (
                rvc.routing_mode != "hardware"
                or rvc.local_stream is None
                or stream_settings_after != stream_settings_before
            )
            if needs_start:
                rvc.routing_mode = "hardware"
                # Explicit stop first so a changed device / block size / model
                # rate is picked up by a freshly opened stream.
                rvc.stop_local_stream()
                rvc.start_local_stream(loop, websocket)
        else:
            if rvc.routing_mode == "hardware":
                rvc.stop_local_stream()
            rvc.routing_mode = "browser"

        return {
            "type": "config_success",
            "model_name": rvc.model_name,
            "target_sr": rvc.target_sr,
            "f0_method": rvc.f0_method,
            "f0_up_key": rvc.f0_up_key,
            "device": rvc.device,
            "routing_mode": rvc.routing_mode
        }

    # --- STAGE 1: Receiver ---
    # Reads ALL WebSocket messages immediately (never blocks on processing).
    # Handles JSON config inline; binary audio chunks go to input_queue.
    async def receiver_task():
        try:
            async for message in websocket:
                if isinstance(message, str):
                    try:
                        data = json.loads(message)
                        if data.get("type") == "config":
                            response = apply_config(data)
                            await websocket.send(json.dumps(response))
                    except Exception as e:
                        logger.error(f"Config parse error: {e}")
                        await websocket.send(json.dumps({"type": "error", "message": str(e)}))

                elif isinstance(message, bytes):
                    if rvc.routing_mode == "hardware":
                        continue
                    if rvc.processor is None:
                        # No model yet: echo the chunk back with a 0.0 processing time.
                        try:
                            input_chunk = np.frombuffer(message, dtype=np.float32)
                        except ValueError as e:
                            # A frame that is not whole float32 samples must not end the session.
                            logger.error(f"Receiver error: {e}")
                            continue
                        await websocket.send(frame_with_timing(0.0, input_chunk))
                        continue
                    # Put chunk in queue — await here means we yield if queue is full (backpressure)
                    await input_queue.put(message)
        except Exception as e:
            logger.error(f"Receiver error: {e}")
        finally:
            # Signal downstream stages to stop
            signal_end(input_queue)

    # --- STAGE 2: Processor ---
    # Pulls from input_queue, processes in thread (non-blocking to event loop), pushes to output_queue.
    # Sequential (1 executor worker) = output order matches input order.
    async def processor_task():
        try:
            while True:
                item = await input_queue.get()
                if item is None:
                    break  # shutdown signal
                try:
                    input_chunk = np.frombuffer(item, dtype=np.float32).copy()
                    t_start = time.time()
                    output_chunk = await loop.run_in_executor(
                        _audio_executor, rvc.process_audio_chunk, input_chunk
                    )
                    t_elapsed = (time.time() - t_start) * 1000
                except asyncio.CancelledError:
                    raise
                except Exception as e:
                    # One bad chunk must not end the stage: otherwise the
                    # receiver would block forever once input_queue fills up.
                    logger.error(f"Processor error: {e}")
                    continue

                # Prepend the elapsed processing time to the audio chunk bytes
                await output_queue.put(frame_with_timing(t_elapsed, output_chunk))
        except asyncio.CancelledError:
            pass
        finally:
            signal_end(output_queue)

    # --- STAGE 3: Sender ---
    # Pulls processed audio from output_queue and sends to client.
    async def sender_task():
        try:
            while True:
                item = await output_queue.get()
                if item is None:
                    break  # shutdown signal
                await websocket.send(item)
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.error(f"Sender error: {e}")

    # --- Visualizer (hardware mode only) ---
    async def visualizer_sender_loop():
        try:
            while True:
                if rvc.routing_mode == "hardware" and rvc.visualizer_queue is not None:
                    try:
                        raw_chunk, output_chunk = await asyncio.wait_for(
                            rvc.visualizer_queue.get(), timeout=0.1
                        )
                        payload = {
                            "type": "visualizer",
                            "input": raw_chunk.tolist(),
                            "output": output_chunk.tolist()
                        }
                        await websocket.send(json.dumps(payload))
                    except asyncio.TimeoutError:
                        pass
                else:
                    await asyncio.sleep(0.05)
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.error(f"Visualizer sender error: {e}")

    # --- Send device list on connect ---
    import sounddevice as sd
    devices_list = []
    try:
        for idx, d in enumerate(sd.query_devices()):
            devices_list.append({
                "id": idx,
                "name": d["name"],
                "max_input_channels": d["max_input_channels"],
                "max_output_channels": d["max_output_channels"],
                "default_samplerate": d["default_samplerate"]
            })
        default_input = sd.default.device[0]
        default_output = sd.default.device[1]
    except Exception as e:
        logger.error(f"Failed to query server audio devices: {e}")
        devices_list = []
        default_input = -1
        default_output = -1

    await websocket.send(json.dumps({
        "type": "init_devices",
        "devices": devices_list,
        "default_input": default_input,
        "default_output": default_output
    }))

    # --- Run all pipeline stages concurrently ---
    background_tasks = [
        asyncio.create_task(visualizer_sender_loop()),
        asyncio.create_task(processor_task()),
        asyncio.create_task(sender_task()),
    ]

    try:
        await receiver_task()  # runs until websocket closes
    except Exception as e:
        logger.error(f"WebSocket handler error: {e}")
    finally:
        for task in background_tasks:
            task.cancel()
        # Release the audio hardware before awaiting anything else.
        rvc.stop_local_stream()
        # Wait for the cancelled stages so none is left pending.
        await asyncio.gather(*background_tasks, return_exceptions=True)
        logger.info("WebSocket client disconnected, pipeline cleaned up.")
|
||||
|
||||
|
||||
async def start_websocket_server(host, port):
    """Serve ``websocket_handler`` on ``host``:``port`` until the task is cancelled."""
    import websockets

    logger.info(f"Starting WebSocket server on ws://{host}:{port}...")
    server = websockets.serve(websocket_handler, host, port)
    async with server:
        # A future that is never resolved keeps the server context open.
        run_forever = asyncio.get_running_loop().create_future()
        await run_forever
|
||||
|
||||
# --- HTTP STATIC FILE SERVER FOR FRONTEND ---
|
||||
def start_http_server(port, directory="frontend"):
    """Serve the static frontend over HTTP; blocks until the server stops.

    Args:
        port: TCP port to listen on (all interfaces).
        directory: Folder to serve, relative to the directory of this file.

    Start-up or serving errors are logged rather than raised.
    """
    # Force serve from directory relative to the project root
    base_dir = os.path.dirname(os.path.abspath(__file__))
    full_dir = os.path.join(base_dir, directory)

    class MyHandler(SimpleHTTPRequestHandler):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, directory=full_dir, **kwargs)

        def log_message(self, format, *args):
            # Suppress standard logging to prevent console pollution
            pass

    class FrontendServer(socketserver.ThreadingTCPServer):
        # Set on a private subclass so the shared socketserver.TCPServer class
        # is not modified for the rest of the process.
        allow_reuse_address = True
        # One thread per request: an idle keep-alive connection cannot stall
        # other requests, and worker threads do not block interpreter exit.
        daemon_threads = True

    try:
        with FrontendServer(("", port), MyHandler) as httpd:
            logger.info(f"Serving HTTP frontend on http://localhost:{port}")
            httpd.serve_forever()
    except Exception as e:
        logger.error(f"Failed to start HTTP server: {e}")
|
||||
|
||||
# --- LOCAL AUDIO DEVICE STREAM MODE ---
|
||||
def run_local_device_mode(model_name, f0_up_key, f0_method, device, input_device, output_device, chunk_size):
    """Run the voice changer directly on local audio devices until Ctrl+C.

    Args:
        model_name: Model folder name to load.
        f0_up_key: Pitch shift in semitones.
        f0_method: Pitch extraction method name.
        device: Requested execution device.
        input_device: sounddevice input device id, or None for the default.
        output_device: sounddevice output device id, or None for the default.
        chunk_size: Stream block size in samples.

    Stream errors are logged (with traceback) rather than raised.
    """
    import sounddevice as sd

    logger.info("Starting Local Audio Hardware Stream Mode...")

    rvc = RealtimeVoiceChanger()
    rvc.load_model(model_name, device)

    if input_device is None:
        input_device = sd.default.device[0]
    if output_device is None:
        output_device = sd.default.device[1]

    input_info = sd.query_devices(input_device)
    output_info = sd.query_devices(output_device)

    target_sr = rvc.target_sr
    # The duplex stream below runs both directions at target_sr, so that is the
    # rate of the captured chunks (not the input device's default rate).
    input_sr = int(target_sr)

    logger.info(f"Input Device: {input_info['name']} (Sample Rate: {input_sr} Hz)")
    logger.info(f"Output Device: {output_info['name']} (Sample Rate: {target_sr} Hz)")

    rvc.set_config({
        "f0_up_key": f0_up_key,
        "f0_method": f0_method,
        "input_sr": input_sr,
        # Use the device load_model actually settled on (e.g. after a CUDA
        # fallback) so set_config does not request the unavailable one again.
        "device": rvc.device,
        "model_name": model_name,
        "noise_gate": -40.0,
        "input_gain": 1.0,
        "output_gain": 1.0
    })

    def audio_callback(indata, outdata, frames, time_info, status):
        if status:
            logger.warning(f"Audio Callback Status: {status}")

        raw_chunk = indata[:, 0].copy()
        output_chunk = rvc.process_audio_chunk(raw_chunk)

        # outdata must always be filled completely.
        if len(output_chunk) < frames:
            outdata[:, 0] = np.pad(output_chunk, (0, frames - len(output_chunk)), "constant")
        else:
            outdata[:, 0] = output_chunk[:frames]

    try:
        stream = sd.Stream(
            device=(input_device, output_device),
            samplerate=target_sr,
            blocksize=chunk_size,
            channels=1,
            dtype="float32",
            callback=audio_callback
        )
        with stream:
            logger.info("Real-Time Sounddevice Stream active! Press Ctrl+C to stop.")
            while True:
                time.sleep(1)
    except KeyboardInterrupt:
        logger.info("Local stream stopped by user")
    except Exception as e:
        logger.error(f"Local stream error: {e}")
        traceback.print_exc()
|
||||
|
||||
# --- MAIN ---
|
||||
if __name__ == "__main__":
    # Command-line options, declared as (flag, add_argument keyword arguments).
    option_specs = (
        ("--mode", dict(type=str, default="websocket", choices=["websocket", "device"], help="Server running mode")),
        ("--host", dict(type=str, default="127.0.0.1", help="WebSocket host")),
        ("--port", dict(type=int, default=8765, help="WebSocket port")),
        ("--http_port", dict(type=int, default=8000, help="HTTP static server port for Web UI")),
        ("--model", dict(type=str, default="", help="RVC Model folder name inside weights/")),
        ("--transpose", dict(type=int, default=0, help="Pitch shift in semitones (transpose)")),
        ("--f0_method", dict(type=str, default="pm", choices=["pm", "harvest", "dio", "rmvpe"], help="Pitch extraction method")),
        ("--device", dict(type=str, default="cuda", choices=["cpu", "cuda", "dml"], help="Execution provider")),
        ("--input_device", dict(type=int, default=None, help="Input device ID (for device mode)")),
        ("--output_device", dict(type=int, default=None, help="Output device ID (for device mode)")),
        ("--chunk_size", dict(type=int, default=2048, help="Audio block size in samples")),
    )
    parser = argparse.ArgumentParser(description="High-Performance Real-Time RVC ONNX Server")
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    # Fall back to the first available model when none was named.
    model_name = args.model
    if not model_name:
        available_models = get_onnx_models()
        if not available_models:
            logger.error("No models found in weights/ directory. Please export a model first.")
            sys.exit(1)
        model_name = available_models[0]
        logger.info(f"Auto-selected model: {model_name}")

    if args.mode == "websocket":
        # 1. Serve the frontend from a background daemon thread.
        http_thread = threading.Thread(
            target=start_http_server,
            args=(args.http_port, "frontend"),
            daemon=True,
        )
        http_thread.start()

        # 2. Open the Web UI in the default browser after a short delay, so
        #    the HTTP server socket has time to open.
        web_ui_url = f"http://127.0.0.1:{args.http_port}"
        logger.info(f"Automatically launching Web UI at {web_ui_url} in browser...")
        browser_timer = threading.Timer(0.5, webbrowser.open, args=(web_ui_url,))
        browser_timer.daemon = True
        browser_timer.start()

        # 3. Run the WebSocket server on the main event loop until interrupted.
        try:
            asyncio.run(start_websocket_server(args.host, args.port))
        except KeyboardInterrupt:
            logger.info("Server shut down")
    elif args.mode == "device":
        run_local_device_mode(
            model_name=model_name,
            f0_up_key=args.transpose,
            f0_method=args.f0_method,
            device=args.device,
            input_device=args.input_device,
            output_device=args.output_device,
            chunk_size=args.chunk_size,
        )
|
||||
Reference in New Issue
Block a user