import io
import logging
import os
import re
import threading
from datetime import datetime, timezone

# Avoid invalid OMP setting from runtime environment (e.g. empty/non-numeric).
_omp_threads = os.getenv("OMP_NUM_THREADS", "").strip()
if not _omp_threads.isdigit() or int(_omp_threads) < 1:
    os.environ["OMP_NUM_THREADS"] = "8"

import torch
from dotenv import load_dotenv
from fastapi import FastAPI, Request, UploadFile
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from PIL import Image, UnidentifiedImageError
from pymongo import MongoClient
from pymongo.errors import PyMongoError, ServerSelectionTimeoutError
from starlette.datastructures import UploadFile as StarletteUploadFile
from transformers import (
    AutoModelForImageTextToText,
    AutoModelForSeq2SeqLM,
    AutoProcessor,
    AutoTokenizer,
)
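
# Configuration is read from environment variables; load_dotenv() lets a local
# .env file supply them during development.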
load_dotenv()

CAPTION_MODEL_ID = os.getenv("CAPTION_MODEL_ID", "vidhi0405/Qwen_I2T")
SUMMARIZER_MODEL_ID = os.getenv("SUMMARIZER_MODEL_ID", "facebook/bart-large-cnn")
DEVICE = os.getenv("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
MAX_NEW_TOKENS = 120
MAX_IMAGES = 5
MONGO_URI = (
    os.getenv("MONGO_URI") or os.getenv("MONGODB_URI") or ""
).strip().strip('"').strip("'")
MONGO_DB_NAME = os.getenv("MONGO_DB_NAME", "image_to_speech")

CAPTION_PROMPT = (
    "Act as a professional news reporter delivering a live on-scene report in real time. "
    "Speak naturally, as if you are addressing viewers who are watching this unfold right now. "
    "Describe the scene in 3 to 4 complete, vivid sentences. "
    "Mention what is happening, the surrounding environment, and the overall mood, "
    "and convey the urgency or emotion of the moment when appropriate."
)
CAPTION_RETRY_PROMPT = (
    "Describe this image in 2 to 3 complete sentences. "
    "Mention the main subject, action, environment, and mood."
)
CAPTION_MIN_SENTENCES = 3
CAPTION_MAX_SENTENCES = 4
PROCESSOR_MAX_LENGTH = 8192
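
# Example .env for local development (the values below are illustrative
# placeholders, not taken from this repository; only the variable names come
# from the code above):
#   CAPTION_MODEL_ID=vidhi0405/Qwen_I2T
#   SUMMARIZER_MODEL_ID=facebook/bart-large-cnn
#   DEVICE=cuda
#   MONGO_URI=mongodb+srv://user:password@cluster.example.mongodb.net
#   MONGO_DB_NAME=image_to_speech
#   PRELOAD_MODELS=1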

logger = logging.getLogger(__name__)


def ok(message: str, data):
    return JSONResponse(
        status_code=200,
        content={"success": True, "message": message, "data": data},
    )


def fail(message: str, status_code: int = 400):
    return JSONResponse(
        status_code=status_code,
        content={"success": False, "message": message, "data": None},
    )


class AppError(Exception):
    def __init__(self, message: str, status_code: int = 400):
        super().__init__(message)
        self.message = message
        self.status_code = status_code
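

# Model handles are loaded lazily and cached in module-level globals; the locks
# keep concurrent requests from loading the same model twice, and
# _caption_force_cpu flags a one-time fallback to CPU after a CUDA failure.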
torch.set_num_threads(8)

_caption_model = None
_caption_processor = None
_caption_lock = threading.Lock()
_caption_force_cpu = False

_summarizer_model = None
_summarizer_tokenizer = None
_summarizer_lock = threading.Lock()

app = FastAPI(title="Image to Text API")

mongo_client = None
mongo_db = None
caption_collection = None
db_init_error = None
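
# Connect to MongoDB at import time and ping it once; any failure is remembered
# in db_init_error so the health endpoint can report it and requests can fail
# with a 503 instead of crashing mid-handler.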
if not MONGO_URI:
    db_init_error = "MONGO_URI (or MONGODB_URI) is not set."
else:
    try:
        mongo_client = MongoClient(MONGO_URI, serverSelectionTimeoutMS=5000)
        mongo_client.admin.command("ping")
        mongo_db = mongo_client[MONGO_DB_NAME]
        caption_collection = mongo_db["captions"]
    except ServerSelectionTimeoutError:
        db_init_error = "Unable to connect to MongoDB (timeout)."
    except PyMongoError as exc:
        db_init_error = "Unable to initialize MongoDB: {}".format(exc)
@app.get("/")
async def root():
    return {
        "success": True,
        "message": "Use POST /generate-caption with form-data key 'file' or 'files' (up to 5 images).",
        "data": None,
    }


@app.get("/health")
async def health():
    if db_init_error:
        return {
            "success": False,
            "message": db_init_error,
            "data": {
                "caption_model_id": CAPTION_MODEL_ID,
                "summarizer_model_id": SUMMARIZER_MODEL_ID,
            },
        }
    return {
        "success": True,
        "message": "ok",
        "data": {
            "caption_model_id": CAPTION_MODEL_ID,
            "summarizer_model_id": SUMMARIZER_MODEL_ID,
        },
    }


@app.on_event("startup")
async def preload_runtime_models():
    if os.getenv("PRELOAD_MODELS", "1").strip().lower() in {"0", "false", "no"}:
        logger.info("Model preloading disabled via PRELOAD_MODELS.")
        return
    try:
        _get_caption_runtime()
        logger.info("Caption model preloaded successfully.")
    except Exception as exc:
        logger.warning("Caption model preload failed: %s", exc)
    try:
        _get_summarizer_runtime()
        logger.info("Summarizer model preloaded successfully.")
    except Exception as exc:
        logger.warning("Summarizer model preload failed: %s", exc)


@app.exception_handler(AppError)
async def app_error_handler(_, exc: AppError):
    return fail(exc.message, exc.status_code)


@app.exception_handler(RequestValidationError)
async def validation_error_handler(_, exc: RequestValidationError):
    return fail("Invalid request payload.", 422)


@app.exception_handler(Exception)
async def unhandled_error_handler(_, exc: Exception):
    logger.exception("Unhandled server error: %s", exc)
    return fail("Internal server error.", 500)


def _ensure_db_ready():
    if db_init_error:
        raise AppError(db_init_error, 503)
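

# Normalize generated text: collapse whitespace, keep up to `max_sentences`
# complete sentences when at least CAPTION_MIN_SENTENCES are present, and
# otherwise trim a dangling clause if the model stopped mid-sentence.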
def _finalize_caption(raw_text: str, max_sentences: int = CAPTION_MAX_SENTENCES) -> str:
    text = " ".join(raw_text.split()).strip()
    if not text:
        return ""
    sentences = re.findall(r"[^.!?]+[.!?]", text)
    sentences = [s.strip() for s in sentences if s.strip()]
    if len(sentences) >= CAPTION_MIN_SENTENCES:
        return " ".join(sentences[:max_sentences]).strip()
    if text and text[-1] not in ".!?":
        text = re.sub(r"[,:;\-]\s*[^,:;\-]*$", "", text).strip()
    return text
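

# Lazily load the captioning model and processor, double-checking inside the lock
# so only one thread performs the load; honors the CPU-fallback flag set after a
# CUDA failure.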
def _get_caption_runtime():
    global _caption_model, _caption_processor, _caption_force_cpu
    if _caption_model is not None and _caption_processor is not None:
        return _caption_model, _caption_processor
    with _caption_lock:
        if _caption_model is None or _caption_processor is None:
            device = "cpu" if _caption_force_cpu else DEVICE
            dtype = torch.float32 if device == "cpu" else DTYPE
            try:
                loaded_model = AutoModelForImageTextToText.from_pretrained(
                    CAPTION_MODEL_ID,
                    trust_remote_code=True,
                    torch_dtype=dtype,
                    low_cpu_mem_usage=True,
                ).to(device)
                loaded_processor = AutoProcessor.from_pretrained(
                    CAPTION_MODEL_ID,
                    trust_remote_code=True,
                )
            except Exception as exc:
                raise AppError("Failed to load caption model.", 503) from exc
            loaded_model.eval()
            _caption_model = loaded_model
            _caption_processor = loaded_processor
    return _caption_model, _caption_processor
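

# Lazily load the summarization model (BART by default) and its tokenizer under
# the same double-checked locking pattern.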
def _get_summarizer_runtime():
    global _summarizer_model, _summarizer_tokenizer
    if _summarizer_model is not None and _summarizer_tokenizer is not None:
        return _summarizer_model, _summarizer_tokenizer
    with _summarizer_lock:
        if _summarizer_model is None or _summarizer_tokenizer is None:
            try:
                tokenizer = AutoTokenizer.from_pretrained(SUMMARIZER_MODEL_ID)
                model = AutoModelForSeq2SeqLM.from_pretrained(
                    SUMMARIZER_MODEL_ID, torch_dtype=DTYPE
                ).to(DEVICE)
            except Exception as exc:
                raise AppError("Failed to load summarization model.", 503) from exc
            model.eval()
            _summarizer_tokenizer = tokenizer
            _summarizer_model = model
    return _summarizer_model, _summarizer_tokenizer
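

# Join the per-image captions and beam-search a single summary; a single caption
# is returned unchanged and an empty list yields an empty string.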
def summarize_captions(captions: list[str]) -> str:
    if not captions:
        return ""
    if len(captions) == 1:
        return captions[0]
    model, tokenizer = _get_summarizer_runtime()
    combined = " ".join(c.strip() for c in captions if c and c.strip())
    if not combined:
        return ""
    try:
        inputs = tokenizer(
            combined,
            max_length=1024,
            truncation=True,
            return_tensors="pt",
        )
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_length=512,
                min_length=100,
                length_penalty=2.0,
                num_beams=4,
                early_stopping=True,
            )
        summary = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
    except Exception as exc:
        raise AppError("Failed to summarize captions.", 500) from exc
    return _finalize_caption(summary, max_sentences=10)
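

# Generate a caption for one image: build a chat-style prompt through the
# processor, retry with the shorter CAPTION_RETRY_PROMPT if the processor reports
# an image-token mismatch, then sample up to MAX_NEW_TOKENS and keep only the
# assistant's reply.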
def generate_caption_text(image: Image.Image, prompt: str = CAPTION_PROMPT) -> str:
    runtime_model, runtime_processor = _get_caption_runtime()
    model_device = str(next(runtime_model.parameters()).device)

    def _build_inputs(prompt: str):
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        text = runtime_processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        return runtime_processor(
            text=text,
            images=image,
            return_tensors="pt",
            truncation=False,
            max_length=PROCESSOR_MAX_LENGTH,
        )

    try:
        inputs = _build_inputs(prompt)
    except Exception as exc:
        if "Mismatch in `image` token count" not in str(exc):
            raise AppError("Failed to preprocess image for captioning.", 422) from exc
        try:
            inputs = _build_inputs(CAPTION_RETRY_PROMPT)
        except Exception as retry_exc:
            raise AppError("Failed to preprocess image during retry.", 422) from retry_exc

    inputs = {k: v.to(model_device) for k, v in inputs.items()}
    try:
        with torch.no_grad():
            outputs = runtime_model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=True,
                top_p=0.9,
                temperature=0.7,
                repetition_penalty=1.2,
            )
        decoded = runtime_processor.decode(outputs[0], skip_special_tokens=True).strip()
    except Exception as exc:
        raise AppError("Caption generation failed.", 500) from exc

    caption = decoded.split("assistant")[-1].lstrip(":\n ").strip()
    return _finalize_caption(caption)
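

# Wrapper that retries caption generation once on CPU if CUDA itself fails
# (e.g. "CUDA error" / "device-side assert"), clearing the cached model so the
# next load targets the CPU.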
def generate_caption_text_safe(image: Image.Image, prompt: str = CAPTION_PROMPT) -> str:
    global _caption_model, _caption_processor, _caption_force_cpu
    try:
        return generate_caption_text(image, prompt)
    except Exception as exc:
        msg = str(exc)
        if "CUDA error" not in msg and "device-side assert" not in msg:
            raise
        with _caption_lock:
            _caption_force_cpu = True
            _caption_model = None
            _caption_processor = None
        if torch.cuda.is_available():
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
        return generate_caption_text(image, prompt)


def insert_record(collection, payload: dict) -> str:
    try:
        result = collection.insert_one(payload)
        return str(result.inserted_id)
    except PyMongoError as exc:
        raise AppError("MongoDB insert failed.", 503) from exc
async def _parse_images(request: Request) -> list[tuple[str, Image.Image]]:
    try:
        form = await request.form()
    except Exception as exc:
        raise AppError("Invalid request payload.", 422) from exc

    uploads: list[UploadFile | StarletteUploadFile] = []
    for key in ("files", "files[]", "file"):
        for value in form.getlist(key):
            if isinstance(value, (UploadFile, StarletteUploadFile)):
                uploads.append(value)

    # Fallback for clients that send non-standard multipart keys.
    if not uploads:
        for _, value in form.multi_items():
            if isinstance(value, (UploadFile, StarletteUploadFile)):
                uploads.append(value)

    if not uploads:
        raise AppError("At least one image is required.", 400)
    if len(uploads) > MAX_IMAGES:
        raise AppError("You can upload a maximum of 5 images.", 400)

    parsed_images = []
    for i, upload in enumerate(uploads):
        if upload.content_type and not upload.content_type.startswith("image/"):
            raise AppError("All uploaded files must be images.", 400)
        try:
            file_bytes = await upload.read()
        except Exception as exc:
            raise AppError("Failed to read uploaded file content.", 400) from exc
        if not file_bytes:
            raise AppError("One of the uploaded images is empty.", 400)
        try:
            image = Image.open(io.BytesIO(file_bytes)).convert("RGB")
        except UnidentifiedImageError as exc:
            raise AppError("One of the uploaded files is not a valid image.", 400) from exc
        except OSError as exc:
            raise AppError("Unable to read one of the uploaded images.", 400) from exc
        filename = upload.filename or f"image_{i+1}"
        parsed_images.append((filename, image))
    return parsed_images
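

# Main endpoint: caption every uploaded image, summarize the captions into a
# single text, persist the record to MongoDB, and return the stored document
# with the inserted id exposed as "audio_file_id".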
@app.post("/generate-caption")
async def generate_caption(request: Request):
    _ensure_db_ready()
    images = await _parse_images(request)

    image_captions = []
    for filename, image in images:
        try:
            caption = generate_caption_text_safe(image)
            if not caption:
                raise AppError(f"Caption generation produced empty text for {filename}.", 500)
            image_captions.append({"filename": filename, "caption": caption})
        except AppError:
            raise
        except Exception as exc:
            logger.error(f"Error generating caption for {filename}: {exc}")
            raise AppError(f"Failed to generate caption for {filename}.", 500) from exc

    caption_texts = [x["caption"] for x in image_captions]
    try:
        caption = summarize_captions(caption_texts)
        if not caption:
            raise AppError("Caption summarization produced empty text.", 500)
    except AppError:
        raise
    except Exception as exc:
        logger.error(f"Summarization error: {exc}")
        raise AppError("Failed to summarize captions.", 500) from exc

    mongo_payload = {
        "caption": caption,
        "source_filenames": [item["filename"] for item in image_captions],
        "image_captions": image_captions,
        "images_count": len(image_captions),
        "is_summarized": len(image_captions) > 1,
        "created_at": datetime.now(timezone.utc),
    }
    try:
        audio_file_id = insert_record(caption_collection, mongo_payload)
    except AppError:
        raise
    except Exception as exc:
        logger.error(f"Database insert error: {exc}")
        raise AppError("Failed to save record to database.", 503) from exc

    response_data = {**mongo_payload, "audio_file_id": audio_file_id}
    response_data.pop("_id", None)  # Remove ObjectId as it is not JSON serializable
    response_data["created_at"] = response_data["created_at"].isoformat()
    return ok("Caption generated successfully.", response_data)
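

# Example request against a local run (module name, host, and port below are
# assumptions, not taken from this file), e.g. after `uvicorn app:app --port 8000`:
#
#   curl -X POST http://localhost:8000/generate-caption \
#        -F "files=@street.jpg" \
#        -F "files=@market.jpg"
#
# On success the response follows the ok() envelope:
#   {"success": true, "message": "Caption generated successfully.", "data": {...}}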