Donut 🍩 (Document Understanding Transformer) is a new, OCR-free method of document understanding: it reads the document image end-to-end with a Transformer instead of relying on a separate OCR engine.
- Based on the Transformer architecture
- Experimented with the sample Colab code below
- Gradio (used below to build the demo UI) is a Python web-UI library, similar to Streamlit
Samples and Results
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
"""colab-demo-for-donut-base-finetuned-cord-v2.ipynb | |
Automatically generated by Colaboratory. | |
Original file is located at | |
https://colab.research.google.com/drive/1o07hty-3OQTvGnc_7lgQFLvvKQuLjqiw?usp=sharing | |
""" | |
!pip install donut-python | |
!pip install gradio | |
import argparse | |
import gradio as gr | |
import torch | |
from PIL import Image | |
from donut import DonutModel | |
def demo_process_vqa(input_img, question):
    """Answer a question about a document image using the loaded Donut model.

    Used as the Gradio handler when the task is ``"docvqa"``.

    Args:
        input_img: Image value delivered by the Gradio ``"image"`` input. It is
            converted with ``Image.fromarray``, so presumably a NumPy array
            (TODO confirm against the Gradio version in use).
        question: User question text. It is substituted for the
            ``{user_input}`` placeholder in the module-level ``task_prompt``.

    Returns:
        The first element of the ``"predictions"`` list returned by
        ``pretrained_model.inference``.
    """
    # This function only *reads* the module-level ``pretrained_model`` and
    # ``task_prompt``, so no ``global`` statement is required (``global`` is
    # only needed to rebind a module-level name).
    image = Image.fromarray(input_img)
    user_prompt = task_prompt.replace("{user_input}", question)
    result = pretrained_model.inference(image=image, prompt=user_prompt)
    return result["predictions"][0]
def demo_process(input_img):
    """Run the loaded Donut model on a document image with the task prompt.

    Used as the Gradio handler for every task except ``"docvqa"``.

    Args:
        input_img: Image value delivered by the Gradio ``"image"`` input. It is
            converted with ``Image.fromarray``, so presumably a NumPy array
            (TODO confirm against the Gradio version in use).

    Returns:
        The first element of the ``"predictions"`` list returned by
        ``pretrained_model.inference`` for the module-level ``task_prompt``.
    """
    # Module-level names are only read here, so no ``global`` statement is needed.
    image = Image.fromarray(input_img)
    result = pretrained_model.inference(image=image, prompt=task_prompt)
    return result["predictions"][0]
# Script section: parse options, build the task prompt, load the Donut
# model, then serve it through a Gradio interface.
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="cord-v2")
parser.add_argument(
    "--pretrained_path",
    type=str,
    default="naver-clova-ix/donut-base-finetuned-cord-v2",
)
# parse_known_args (rather than parse_args) does not fail on extra argv
# entries, e.g. ones a notebook kernel may pass; the leftovers go unused.
args, left_argv = parser.parse_known_args()
task_name = args.task
is_docvqa = task_name == "docvqa"

# DocVQA needs a prompt with a {user_input} slot (filled by
# demo_process_vqa); other tasks (rvlcdip, cord, ...) use a bare start token.
task_prompt = (
    "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
    if is_docvqa
    else f"<s_{task_name}>"
)

pretrained_model = DonutModel.from_pretrained(args.pretrained_path)
if not torch.cuda.is_available():
    # CPU: cast only the encoder to bfloat16.
    pretrained_model.encoder.to(torch.bfloat16)
else:
    # GPU: whole model in half precision, then moved to the CUDA device.
    pretrained_model.half()
    device = torch.device("cuda")
    pretrained_model.to(device)
pretrained_model.eval()

# DocVQA takes an image plus a question; every other task takes just an image.
if is_docvqa:
    handler, input_spec = demo_process_vqa, ["image", "text"]
else:
    handler, input_spec = demo_process, "image"

demo = gr.Interface(
    fn=handler,
    inputs=input_spec,
    outputs="json",
    title=f"Donut 🍩 demonstration for `{task_name}` task",
)
demo.launch()
Keep Exploring!!!
No comments:
Post a Comment