| |
| """ |
| RVC AI 翻唱 - 主入口 |
| """ |
| import os |
| import sys |
| import argparse |
| from pathlib import Path |
|
|
| |
# Make the project root importable regardless of the current working
# directory, so `lib.*`, `tools.*` and `ui.*` resolve when run as a script.
ROOT_DIR = Path(__file__).parent
sys.path.insert(0, str(ROOT_DIR))


from lib.logger import log
|
|
|
|
def check_environment():
    """Check the runtime environment and report available acceleration.

    Logs the Python and PyTorch versions and which GPU backend (CUDA/ROCm,
    Intel XPU, DirectML, Apple MPS) is usable, falling back to a CPU warning.

    Returns:
        bool: True when PyTorch is importable, False otherwise.
    """
    log.header("RVC AI 翻唱系统")

    py_version = sys.version_info
    log.info(f"Python 版本: {py_version.major}.{py_version.minor}.{py_version.micro}")

    if py_version < (3, 8):
        log.warning("建议使用 Python 3.8 或更高版本")

    # Keep the try narrow: previously it also wrapped the lib.device and
    # torch_directml imports, so an ImportError from those modules was
    # misreported as "PyTorch not installed".
    try:
        import torch
    except ImportError:
        log.error("未安装 PyTorch")
        return False

    log.info(f"PyTorch 版本: {torch.__version__}")

    from lib.device import get_device_info, _is_rocm, _has_xpu, _has_directml, _has_mps
    info = get_device_info()
    log.info(f"可用加速后端: {', '.join(info['backends'])}")

    if torch.cuda.is_available():
        # torch.cuda covers both NVIDIA CUDA and AMD ROCm builds; probe once.
        is_rocm = _is_rocm()
        backend = "ROCm" if is_rocm else "CUDA"
        log.info(f"{backend} 版本: {torch.version.hip if is_rocm else torch.version.cuda}")
        log.info(f"GPU: {torch.cuda.get_device_name(0)}")
    elif _has_xpu():
        log.info(f"Intel GPU: {torch.xpu.get_device_name(0)}")
    elif _has_directml():
        # _has_directml() returning True implies torch_directml is importable.
        import torch_directml
        log.info(f"DirectML 设备: {torch_directml.device_name(0)}")
    elif _has_mps():
        log.info("Apple MPS 加速可用")
    else:
        log.warning("未检测到 GPU 加速,将使用 CPU")

    return True
|
|
|
|
def check_models():
    """Ensure all required models are present, downloading missing ones.

    Returns:
        bool: True when every required model is available (already on disk
        or successfully downloaded), False when the download fails.
    """
    from tools.download_models import check_model, REQUIRED_MODELS

    missing = [name for name in REQUIRED_MODELS if not check_model(name)]

    if not missing:
        return True

    log.warning(f"缺少必需模型: {', '.join(missing)}")
    log.info("正在下载...")
    from tools.download_models import download_required_models
    if not download_required_models():
        log.error("模型下载失败,请检查网络连接")
        return False

    return True
|
|
|
|
def main():
    """CLI entry point: parse arguments, run checks, launch the Gradio UI."""
    parser = argparse.ArgumentParser(description="RVC AI 翻唱系统")
    parser.add_argument(
        "--host", type=str, default="127.0.0.1",
        help="服务器地址 (默认: 127.0.0.1)",
    )
    parser.add_argument(
        "--port", type=int, default=7860,
        help="服务器端口 (默认: 7860)",
    )
    parser.add_argument("--share", action="store_true", help="创建公共链接")
    parser.add_argument("--skip-check", action="store_true", help="跳过环境检查")
    parser.add_argument("--download-models", action="store_true", help="仅下载模型")
    args = parser.parse_args()

    # Download-only mode: fetch every model and exit without launching the UI.
    if args.download_models:
        from tools.download_models import download_all_models
        download_all_models()
        return

    # Environment and model checks, both skippable via --skip-check.
    if not args.skip_check:
        if not check_environment():
            sys.exit(1)
        if not check_models():
            log.info("提示: 可以使用 --skip-check 跳过检查")
            sys.exit(1)

    log.info(f"启动 Gradio 界面: http://{args.host}:{args.port}")
    from ui.app import launch
    launch(host=args.host, port=args.port, share=args.share)
|
|
|
|
# Run the CLI entry point only when executed directly as a script.
if __name__ == "__main__":
    main()
|
|