from fastapi import FastAPI, UploadFile, File, HTTPException
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image, UnidentifiedImageError
import cv2
import numpy as np
from io import BytesIO
import uvicorn

# Load the OCR model and processor once at import time — they are large and
# expensive to construct, so per-request loading would be prohibitive.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

app = FastAPI()


def detect_lines(image, min_height=20, min_width=100):
    """Detect horizontal lines of text in *image*.

    Args:
        image: a PIL RGB image (3-channel; the caller converts with .convert("RGB")).
        min_height: bounding boxes shorter than this are discarded as noise.
        min_width: bounding boxes narrower than this are discarded as noise.

    Returns:
        A list of NumPy arrays, one cropped image per detected line,
        ordered top-to-bottom.
    """
    image_np = np.array(image)

    # Binarize: Otsu picks the threshold automatically; INV makes ink white
    # on black, which is the polarity findContours expects.
    gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
    _, binary = cv2.threshold(gray, 128, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    # Dilate so characters on the same line merge into one connected blob.
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
    dilated = cv2.dilate(binary, kernel, iterations=1)

    contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Sort top-to-bottom by y coordinate, then drop boxes too small to
    # plausibly be a line of text.
    bounding_boxes = sorted((cv2.boundingRect(c) for c in contours), key=lambda b: b[1])
    filtered_boxes = [
        (x, y, w, h)
        for x, y, w, h in bounding_boxes
        if h >= min_height and w >= min_width
    ]

    return [image_np[y:y + h, x:x + w] for x, y, w, h in filtered_boxes]


def _ocr_line(line_img):
    """Run TrOCR on a single cropped line image (NumPy array) and return its text."""
    line_pil = Image.fromarray(line_img)
    pixel_values = processor(images=line_pil, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


@app.post("/process_image")
async def process_image(file: UploadFile = File(...)):
    """Extract multiline handwritten text from an uploaded image.

    Returns:
        JSON of the form {"extracted_text": "..."} with one recognized line
        of text (newline-terminated) per detected image line.

    Raises:
        HTTPException 400: when the upload is not a decodable image.
        HTTPException 500: when line detection or OCR fails unexpectedly.
    """
    contents = await file.read()

    # A bad upload is a client error (400), not a server error — keep the
    # try body to just the lines that can raise on malformed input.
    try:
        image = Image.open(BytesIO(contents)).convert("RGB")
    except (UnidentifiedImageError, OSError) as e:
        raise HTTPException(status_code=400, detail=f"Error processing image: {e}")

    try:
        line_images = detect_lines(image, min_height=30, min_width=100)
        # NOTE(review): model.generate is CPU-bound and blocks the event loop
        # while this coroutine runs; under load consider
        # fastapi.concurrency.run_in_threadpool — confirm against deployment needs.
        texts = [_ocr_line(line_img) for line_img in line_images]
    except Exception as e:
        # Unexpected failures in detection/OCR are server-side errors,
        # not the client's fault — report 500, not 400.
        raise HTTPException(status_code=500, detail=f"Error processing image: {e}")

    # Join instead of += accumulation; preserves the original format of one
    # newline-terminated entry per detected line.
    extracted_text = "".join(f"{t}\n" for t in texts)
    return {"extracted_text": extracted_text}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)