| import os |
| import librosa |
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| import soundfile as sf |
| from glob import glob |
| from tqdm import tqdm |
| from os.path import basename, join, exists |
| from vq.codec_encoder import CodecEncoder |
| |
| from vq.codec_decoder_vocos import CodecDecoderVocos |
| from argparse import ArgumentParser |
| from time import time |
| from transformers import AutoModel, AutoFeatureExtractor, Wav2Vec2BertModel |
| import torch.nn as nn |
| from vq.module import SemanticDecoder, SemanticEncoder |
| from torch.utils.data import Dataset, DataLoader |
| from typing import List, Tuple |
| from collections import OrderedDict |
| import torchaudio |
| from torchaudio.transforms import Resample |
| import pandas as pd |
| import numpy as np |
|
|
def pad_audio_batch(batch):
    """Collate a batch of (audio, feat, fname, n_frames) samples into tensors.

    Pads every waveform up to the longest feature length times the hop size
    (320 samples per feature frame) and every feature map up to the longest
    feature length, so the batch can be stacked into dense tensors.

    Args:
        batch: list of tuples (audio [1, T], feat [1, T_feat, D], fname, n_frames)
            as produced by ``WaveDataset.__getitem__``.

    Returns:
        Tuple of (padded_audios [B, 1, T_max], padded_feats [B, 1, T_feat_max, D],
        fnames, lengths); fnames and lengths are tuples as produced by zip.
    """
    audio_list, feat_list, fname_list, audio_length = zip(*batch)
    feat_list = list(feat_list)

    # One feature frame corresponds to 320 audio samples (the codec hop size),
    # so the audio target length is frame-aligned with the feature target.
    max_length_feat = max(feat.shape[1] for feat in feat_list)
    max_length = max_length_feat * 320

    padded_audios = []
    for audio in audio_list:
        padding = max_length - audio.shape[1]
        if padding > 0:
            padded_audio = F.pad(audio, (0, padding), mode='constant', value=0)
        else:
            # Truncate audio that overshoots the frame-aligned target length.
            padded_audio = audio[:, :max_length]
        padded_audios.append(padded_audio)
    padded_audios = torch.stack(padded_audios)

    padded_feat_list = []
    for feat in feat_list:
        padding = max_length_feat - feat.shape[1]
        # Pad the time axis (dim -2); the trailing feature axis stays untouched.
        padded_feat_list.append(F.pad(feat, (0, 0, 0, padding), mode='constant', value=0))
    padded_feats = torch.stack(padded_feat_list)

    # Bug fix: the stacked results were previously re-wrapped in torch.tensor(...),
    # which copy-constructs an existing tensor and raises a UserWarning.
    return padded_audios, padded_feats, fname_list, audio_length
|
|
class WaveDataset(Dataset):
    """Dataset yielding (waveform, w2v-BERT input features, path, n_frames).

    Each item is an audio file loaded at ``sampling_rate`` together with the
    facebook/w2v-bert-2.0 input features of its half-hop-padded waveform.
    """

    def __init__(
        self,
        file_list,
        sampling_rate,
        audio_norm_scale: float = 1.0,
        root_dir: str = ""
    ):
        self.file_list = file_list
        self.sampling_rate = sampling_rate
        self.audio_norm_scale = audio_norm_scale
        # Codec hop size: one code/feature frame per 320 samples.
        self.hop_length = 320
        self.root_dir = root_dir
        self.feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")

    def __getitem__(self, index):
        path = os.path.join(self.root_dir, self.file_list[index])

        waveform, src_sr = torchaudio.load(path)
        if src_sr != self.sampling_rate:
            waveform = Resample(src_sr, self.sampling_rate)(waveform)
        if self.audio_norm_scale < 1.0:
            waveform = waveform * self.audio_norm_scale

        # Symmetric half-hop (160-sample) padding before feature extraction.
        padded = F.pad(waveform, (160, 160))
        features = self.feature_extractor(
            padded,
            sampling_rate=self.sampling_rate,
            return_tensors="pt",
        ).data['input_features']

        n_frames = int(waveform.shape[1] / self.hop_length)
        return waveform, features, path, n_frames

    def __len__(self):
        return len(self.file_list)
|
|
def save_vq_code(vq_codes: torch.Tensor, wav_paths: List[str], lengths: List[int], output_dir: str, input_dir: str = None):
    """Save per-utterance VQ codes as int32 ``.npy`` files mirroring the input tree.

    Args:
        vq_codes: code tensor of shape [B, 1, T]; entry ``i`` is trimmed to ``lengths[i]``.
        wav_paths: audio path of each batch item.
        lengths: number of valid (un-padded) code frames per item.
        output_dir: root directory; codes are written under ``<output_dir>/vq_codes/``.
        input_dir: root the wav paths are made relative to; defaults to the
            module-level CLI argument ``args.input_dir`` for backward compatibility.
    """
    if input_dir is None:
        input_dir = args.input_dir  # fall back to the script's parsed CLI namespace
    for i, wav_path in enumerate(wav_paths):
        relative_path = os.path.relpath(wav_path, input_dir)
        # splitext generalizes the original ".flac" -> ".npy" replacement to
        # any audio extension (.wav, .mp3, ...).
        code_path = os.path.join(output_dir, 'vq_codes', os.path.splitext(relative_path)[0] + '.npy')
        os.makedirs(os.path.dirname(code_path), exist_ok=True)
        vq_code = vq_codes[i, 0, :lengths[i]]
        np.save(code_path, vq_code.detach().cpu().numpy().astype(np.int32))
|
|
if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--local-rank', type=int, default=0, help='Local GPU device ID')
    parser.add_argument('--input-dir', type=str, default='/path/to/audio_folder', help='Input directory containing audio files')
    parser.add_argument('--flist_file', type=str, default='/path/to/file.txt', help='TSV file containing paths to audio files')
    parser.add_argument('--ckpt', type=str, default='/path/to/epoch=4-step=1400000.ckpt', help='Path to the model checkpoint')
    parser.add_argument('--output-dir', type=str, default='/path/to/saving_code_folder', help='Output directory for saving audio files')
    parser.add_argument('--batch_size', type=int, default=6, help='Batch size for processing')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of worker threads for the DataLoader')
    args = parser.parse_args()

    # Bug fix: --local-rank was parsed but never used. Prefer the
    # torchrun-provided LOCAL_RANK env var, fall back to the CLI flag.
    device_id = int(os.getenv('LOCAL_RANK', args.local_rank))
    sr = 16000  # the codec and w2v-BERT both operate on 16 kHz audio

    os.makedirs(args.output_dir, exist_ok=True)

    print(f'loading codec checkpoint from {args.ckpt}')
    # NOTE(review): torch.load unpickles arbitrary objects — only point --ckpt
    # at trusted checkpoint files.
    ckpt = torch.load(args.ckpt, map_location='cpu')
    ckpt = ckpt['state_dict']

    # Split the flat Lightning state_dict into per-module state dicts by
    # stripping each submodule's key prefix.
    filtered_state_dict_codec = OrderedDict()
    filtered_state_dict_semantic_encoder = OrderedDict()
    filtered_state_dict_gen = OrderedDict()
    filtered_state_dict_fc_post_a = OrderedDict()
    filtered_state_dict_fc_prior = OrderedDict()

    prefix_to_target = {
        'CodecEnc.': filtered_state_dict_codec,
        'generator.': filtered_state_dict_gen,
        'fc_post_a.': filtered_state_dict_fc_post_a,
        'SemanticEncoder_module.': filtered_state_dict_semantic_encoder,
        'fc_prior.': filtered_state_dict_fc_prior,
    }
    for key, value in ckpt.items():
        for prefix, target in prefix_to_target.items():
            if key.startswith(prefix):
                target[key[len(prefix):]] = value
                break

    # Frozen w2v-BERT supplying semantic targets (hidden layer 16 used below).
    semantic_model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", output_hidden_states=True)
    semantic_model.eval()

    SemanticEncoder_module = SemanticEncoder(1024, 1024, 1024)
    SemanticEncoder_module.load_state_dict(filtered_state_dict_semantic_encoder)
    SemanticEncoder_module.eval()

    encoder = CodecEncoder()
    encoder.load_state_dict(filtered_state_dict_codec)
    encoder.eval()

    decoder = CodecDecoderVocos()
    decoder.load_state_dict(filtered_state_dict_gen)
    decoder.eval()

    fc_post_a = nn.Linear(2048, 1024)
    fc_post_a.load_state_dict(filtered_state_dict_fc_post_a)
    fc_post_a.eval()

    fc_prior = nn.Linear(2048, 2048)
    fc_prior.load_state_dict(filtered_state_dict_fc_prior)
    fc_prior.eval()

    device = torch.device(f'cuda:{device_id}' if torch.cuda.is_available() else 'cpu')
    for module in (semantic_model, SemanticEncoder_module, encoder, decoder, fc_post_a, fc_prior):
        module.to(device)

    # File list: TSV with a header row (skipped) and (filename, duration) columns.
    df = pd.read_csv(args.flist_file, sep='\t', header=None, names=['filename', 'duration'], skiprows=1)
    file_list = df['filename'].tolist()

    # Shard the file list across worker processes. WORLD_SIZE (set by torchrun)
    # keeps the shard count in sync with the number of launched processes; the
    # default of 8 preserves the original hard-coded value.
    num_shards = int(os.getenv('WORLD_SIZE', 8))
    current_file_list = np.array_split(file_list, num_shards)[device_id]

    dataset = WaveDataset(file_list=current_file_list, sampling_rate=sr, root_dir=args.input_dir)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        collate_fn=pad_audio_batch
    )

    st = time()
    for batch in tqdm(dataloader, desc="processing"):
        wavs, feats, wav_paths, lengths = batch
        wavs = wavs.to(device)

        with torch.no_grad():
            # Acoustic branch: codec encoder embeddings -> [B, D, T].
            vq_emb = encoder(wavs)
            vq_emb = vq_emb.transpose(1, 2)

            # Semantic branch: layer-16 hidden states of w2v-BERT, projected
            # by the trained semantic encoder (time axis last after transpose).
            semantic_target = semantic_model(feats[:, 0, :, :].to(device))
            semantic_target = semantic_target.hidden_states[16]
            semantic_target = semantic_target.transpose(1, 2)
            semantic_target = SemanticEncoder_module(semantic_target)

            # Fuse both branches, project, and quantize to discrete codes.
            vq_emb = torch.cat([semantic_target, vq_emb], dim=1)
            vq_emb = fc_prior(vq_emb.transpose(1, 2)).transpose(1, 2)
            _, vq_code, _ = decoder(vq_emb, vq=True)

        save_vq_code(vq_code, wav_paths, lengths, args.output_dir)

    et = time()
    print(f'End,time: {(et - st)/60:.2f} mins')
|