Spaces:
Build error
Build error
| import os | |
| import tempfile | |
| import unittest | |
| from finetrainers.data import ( | |
| InMemoryDistributedDataPreprocessor, | |
| PrecomputedDistributedDataPreprocessor, | |
| VideoCaptionFilePairDataset, | |
| initialize_preprocessor, | |
| wrap_iterable_dataset_for_preprocessing, | |
| ) | |
| from finetrainers.data.precomputation import PRECOMPUTED_DATA_DIR | |
| from finetrainers.utils import find_files | |
| from .utils import create_dummy_directory_structure | |
class PreprocessorFastTests(unittest.TestCase):
    """Fast tests for the in-memory and precomputed distributed preprocessors.

    Every test works on a temporary directory populated with ``num_items``
    dummy video/caption file pairs and on a dataset wrapped by
    ``wrap_iterable_dataset_for_preprocessing``.  Two toy processors are
    registered: one slices the video, the other appends a suffix to the
    caption, so the tests can tell processed items from raw ones.
    """

    # Caption written to the dummy files, followed by the suffix appended by
    # ``_condition_processor_fn``.
    _EXPECTED_CAPTION = "a cat ruling the world surrounded by mystical aura"

    def setUp(self):
        """Create the dummy data directory and the wrapped dataset."""
        self.rank = 0
        self.world_size = 1
        self.num_items = 3
        self.processor_fn = {
            "latent": self._latent_processor_fn,
            "condition": self._condition_processor_fn,
        }

        self.save_dir = tempfile.TemporaryDirectory()
        # Registered right away so the directory is removed even when a later
        # step of setUp raises (tearDown is not run in that case).
        self.addCleanup(self.save_dir.cleanup)

        # One "<i>.mp4" and one "<i>.txt" per item, videos first.
        directory_structure = [f"{index}.mp4" for index in range(self.num_items)]
        directory_structure += [f"{index}.txt" for index in range(self.num_items)]
        create_dummy_directory_structure(
            directory_structure, self.save_dir, self.num_items, "a cat ruling the world", "mp4"
        )

        dataset = VideoCaptionFilePairDataset(self.save_dir.name, infinite=True)
        self.dataset = wrap_iterable_dataset_for_preprocessing(
            dataset,
            dataset_type="video",
            config={
                "video_resolution_buckets": [[2, 32, 32]],
                "reshape_mode": "bicubic",
            },
        )

    @staticmethod
    def _latent_processor_fn(**data):
        """Return ``data`` with ``data["video"]`` sliced to ``[:, :, :16, :16]``.

        Declared static because it is stored in ``self.processor_fn`` and must
        accept keyword arguments only; as a plain method the instance would be
        passed positionally and the call would raise ``TypeError``.
        """
        data["video"] = data["video"][:, :, :16, :16]
        return data

    @staticmethod
    def _condition_processor_fn(**data):
        """Return ``data`` with a fixed suffix appended to ``data["caption"]``.

        Declared static for the same reason as ``_latent_processor_fn``.
        """
        data["caption"] = data["caption"] + " surrounded by mystical aura"
        return data

    def _make_preprocessor(self, enable_precomputation):
        """Build a preprocessor from the fixture values set up in ``setUp``."""
        return initialize_preprocessor(
            self.rank,
            self.world_size,
            self.num_items,
            self.processor_fn,
            self.save_dir.name,
            enable_precomputation=enable_precomputation,
        )

    def _start_pipeline(self, consume_method, enable_precomputation):
        """Create a preprocessor and start its condition and latent consumers.

        Args:
            consume_method: Name of the preprocessor method to call, either
                ``"consume"`` or ``"consume_once"``.
            enable_precomputation: Forwarded to ``initialize_preprocessor``.

        Returns:
            A ``(preprocessor, condition_iterator, latent_iterator)`` tuple.
        """
        # Both consumers share one data iterator: the condition pass caches the
        # samples and the latent pass reuses (and then drops) them.
        data_iterator = iter(self.dataset)
        preprocessor = self._make_preprocessor(enable_precomputation)
        consume = getattr(preprocessor, consume_method)
        condition_iterator = consume(
            "condition", components={}, data_iterator=data_iterator, cache_samples=True
        )
        latent_iterator = consume(
            "latent", components={}, data_iterator=data_iterator, use_cached_samples=True, drop_samples=True
        )
        return preprocessor, condition_iterator, latent_iterator

    def _assert_precomputed_files(self):
        """Check that one condition and one latent file exist per item."""
        precomputed_data_dir = os.path.join(self.save_dir.name, PRECOMPUTED_DATA_DIR)
        condition_file_list = find_files(precomputed_data_dir, "condition-*")
        latent_file_list = find_files(precomputed_data_dir, "latent-*")
        self.assertEqual(len(condition_file_list), self.num_items)
        self.assertEqual(len(latent_file_list), self.num_items)

    def _assert_processed_items(self, condition_iterator, latent_iterator):
        """Pull ``num_items`` items from each iterator and check both processors ran."""
        for _ in range(self.num_items):
            condition_item = next(condition_iterator)
            latent_item = next(latent_iterator)
            self.assertIn("caption", condition_item)
            self.assertIn("video", latent_item)
            self.assertEqual(condition_item["caption"], self._EXPECTED_CAPTION)
            self.assertEqual(latent_item["video"].shape[-2:], (16, 16))

    def test_initialize_preprocessor(self):
        """``enable_precomputation`` selects the preprocessor class."""
        preprocessor = self._make_preprocessor(enable_precomputation=False)
        self.assertIsInstance(preprocessor, InMemoryDistributedDataPreprocessor)

        preprocessor = self._make_preprocessor(enable_precomputation=True)
        self.assertIsInstance(preprocessor, PrecomputedDistributedDataPreprocessor)

    def test_in_memory_preprocessor_consume(self):
        """In-memory ``consume``: ``requires_data`` flips to True once drained."""
        preprocessor, condition_iterator, latent_iterator = self._start_pipeline(
            "consume", enable_precomputation=False
        )

        self.assertFalse(preprocessor.requires_data)
        self._assert_processed_items(condition_iterator, latent_iterator)
        self.assertTrue(preprocessor.requires_data)

    def test_in_memory_preprocessor_consume_once(self):
        """In-memory ``consume_once``: ``requires_data`` stays False once drained."""
        preprocessor, condition_iterator, latent_iterator = self._start_pipeline(
            "consume_once", enable_precomputation=False
        )

        self.assertFalse(preprocessor.requires_data)
        self._assert_processed_items(condition_iterator, latent_iterator)
        self.assertFalse(preprocessor.requires_data)

    def test_precomputed_preprocessor_consume(self):
        """Precomputed ``consume``: files are written, then ``requires_data`` flips to True."""
        preprocessor, condition_iterator, latent_iterator = self._start_pipeline(
            "consume", enable_precomputation=True
        )

        self._assert_precomputed_files()
        self.assertFalse(preprocessor.requires_data)
        self._assert_processed_items(condition_iterator, latent_iterator)
        self.assertTrue(preprocessor.requires_data)

    def test_precomputed_preprocessor_consume_once(self):
        """Precomputed ``consume_once``: files are written and ``requires_data`` stays False."""
        preprocessor, condition_iterator, latent_iterator = self._start_pipeline(
            "consume_once", enable_precomputation=True
        )

        self._assert_precomputed_files()
        self.assertFalse(preprocessor.requires_data)
        self._assert_processed_items(condition_iterator, latent_iterator)
        self.assertFalse(preprocessor.requires_data)