| from fastapi import FastAPI, UploadFile, File, HTTPException |
| from transformers import TrOCRProcessor, VisionEncoderDecoderModel |
| from PIL import Image |
| import cv2 |
| import numpy as np |
| from io import BytesIO |
| import uvicorn |
|
|
| |
# Load the TrOCR handwritten-text checkpoint once at import time so every
# request reuses the same weights (the first run downloads them from the
# Hugging Face hub — this can take a while and requires network access).
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")


app = FastAPI()
|
|
| |
def detect_lines(image, min_height=20, min_width=100, kernel_size=(5, 5)):
    """
    Detect lines of text in *image* and return them as cropped arrays.

    The image is binarized (inverted Otsu threshold, so ink becomes white),
    dilated so neighboring glyphs fuse into one connected blob per line,
    and the external contours of those blobs become candidate line boxes.

    Parameters
    ----------
    image : PIL.Image.Image or array-like
        RGB input image.
    min_height, min_width : int
        Bounding boxes smaller than either limit are discarded — filters
        out noise such as stray dots or isolated punctuation contours.
    kernel_size : tuple[int, int]
        Dilation structuring-element size. A wider kernel merges words on
        the same line more aggressively; the default matches the original
        hard-coded (5, 5) behavior.

    Returns
    -------
    list[numpy.ndarray]
        RGB crops of the detected lines, ordered top to bottom.
    """
    image_np = np.array(image)

    # Grayscale -> inverted Otsu threshold: text becomes white on black,
    # which is what dilate/findContours expect to operate on.
    gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
    _, binary = cv2.threshold(gray, 128, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    # Dilate so adjacent characters merge into a single blob per text line.
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, kernel_size)
    dilated = cv2.dilate(binary, kernel, iterations=1)

    contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Sort by y so the crops come out in reading order (top to bottom),
    # then drop boxes too small to plausibly be a text line.
    boxes = sorted((cv2.boundingRect(c) for c in contours), key=lambda b: b[1])
    return [
        image_np[y:y + h, x:x + w]
        for x, y, w, h in boxes
        if h >= min_height and w >= min_width
    ]
|
|
@app.post("/process_image")
async def process_image(file: UploadFile = File(...)):
    """
    Extract multiline handwritten text from an uploaded image.

    The upload is decoded to RGB, split into text lines with
    ``detect_lines``, and each line crop is run through TrOCR separately
    (the model expects single-line inputs).

    Returns
    -------
    dict
        ``{"extracted_text": str}`` — the recognized lines, each followed
        by a newline; an empty string when no lines are detected.

    Raises
    ------
    HTTPException
        400 with the underlying error message when decoding or
        recognition fails.
    """
    try:
        contents = await file.read()
        image = Image.open(BytesIO(contents)).convert("RGB")

        line_images = detect_lines(image, min_height=30, min_width=100)

        # Recognize each detected line independently; collect the results
        # and join once (str.join avoids quadratic += concatenation).
        line_texts = []
        for line_img in line_images:
            line_pil = Image.fromarray(line_img)
            pixel_values = processor(images=line_pil, return_tensors="pt").pixel_values
            generated_ids = model.generate(pixel_values)
            line_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            line_texts.append(line_text)

        extracted_text = "".join(f"{text}\n" for text in line_texts)
        return {"extracted_text": extracted_text}

    except Exception as e:
        # API boundary: surface any decode/OCR failure as a client-visible
        # 400, preserving the original traceback via exception chaining.
        raise HTTPException(status_code=400, detail=f"Error processing image: {str(e)}") from e
|
|
# Dev entry point: run the API directly with `python <this file>`.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
|