File size: 2,649 Bytes
86ecc5c
 
ee7b78c
86ecc5c
 
 
ba2ca67
86ecc5c
 
 
 
 
 
 
ba2ca67
c2607f2
 
86ecc5c
 
2f371ff
c2607f2
 
86ecc5c
 
2f371ff
c2607f2
 
86ecc5c
c2607f2
86ecc5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee7b78c
8202567
 
 
86ecc5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import gradio as gr
from ultralytics import YOLO
import random
import os
from pathlib import Path

# Base URL of the directory that holds the model weights. run_models()
# appends '<model_type>_<hp|spk|both>.pt' to it to select a checkpoint.
model_path = 'https://huggingface.co/mayrajeo/yolov8-deadwood/resolve/main/models/'

def run_models(
        im=None,
        model_type: str = 'YOLOv8n',
        conf_thr: float = 0.25
):
    """Run the HP, SPK and HP+SPK models on one image and plot the results.

    For each of the three training-data variants the matching checkpoint
    ``{model_path}{model_type}_{variant}.pt`` is loaded, moved to the CPU,
    run on the image and its predictions are rendered with ``plot()``.

    Args:
        im: Input image as a 3-D array indexed ``[row, col, channel]``
            (presumably RGB as delivered by the ``gr.Image`` input -- confirm
            if the component's ``type``/``image_mode`` is ever changed).
        model_type: Model size prefix used in the checkpoint file name,
            e.g. ``'YOLOv8n'``.
        conf_thr: Confidence threshold forwarded to the model call as
            ``conf``.

    Returns:
        A list of ``(image, caption)`` tuples in the order HP, SPK, HP+SPK,
        where each image is the plotted prediction with its channel axis
        reversed again for display.

    Raises:
        gr.Error: If no input image was provided.
    """
    # Clicking the button without an image passes None; fail with a
    # readable message instead of a TypeError from the slicing below.
    if im is None:
        raise gr.Error('Please upload or select an input image first.')

    # Reverse the channel axis once for the model input; the same view is
    # reused for all three models.
    model_input = im[:, :, ::-1]

    predictions = []
    for suffix, caption in (('hp', 'HP'), ('spk', 'SPK'), ('both', 'HP+SPK')):
        # NOTE: the checkpoint is loaded on every call, exactly as before.
        model = YOLO(f'{model_path}{model_type}_{suffix}.pt')
        model.to(device='cpu')
        result = model(model_input, conf=conf_thr)
        plotted = result[0].plot()
        # Flip the channel axis back so the gallery gets the same channel
        # order as the uploaded image.
        predictions.append((plotted[:, :, ::-1], caption))

    return predictions

# Directory whose entries are offered as example inputs in the UI below.
ex_dir = Path('examples')

# Created outside the Blocks context because gr.Examples refers to it
# before it is placed in the layout with loc.render(). It is filled from
# the file-name prefix of the selected example.
loc = gr.Textbox(label='Location')

# Markdown text shown in the first row of the page.
desc_str = """
Demo application for YOLOv8 models for deadwood segmentation from RGB UAV imagery. Results are shown on three different models: HP is trained only with data from Hiidenportti, 
SPK only with data from Sudenpesänkangas and HP+SPK is trained with both sites.
"""

# Build the page: description, then inputs/examples/settings, then results.
with gr.Blocks() as demo:
    # First row: description of the demo.
    with gr.Row():
        gr.Markdown(desc_str)

    # Second row: input image, example picker and inference settings.
    # Column widths are given with the documented `scale` keyword.
    with gr.Row():
        with gr.Column(scale=2):
            inp = gr.Image(label='Input image', sources='upload')
        with gr.Column(scale=1):
            # Offer a random subset of at most 15 examples. Capping the
            # sample size at the number of available files keeps
            # random.sample from raising ValueError when the directory
            # holds fewer than 15 entries.
            ex_files = os.listdir(ex_dir)
            picked = random.sample(ex_files, min(15, len(ex_files)))
            # Each example is [image path, location]; the location is the
            # part of the file name before the first underscore. Paths are
            # passed as plain strings.
            ex_list = [
                [str(ex_dir / fname), fname.split('_')[0]]
                for fname in picked
            ]
            ex = gr.Examples(ex_list, inputs=[inp, loc],
                             cache_examples=False, examples_per_page=5,
                             label='Example UAV images')
        with gr.Column(scale=1):
            # `loc` was created at module level; place it in the layout here.
            loc.render()
            model = gr.Dropdown(
                [
                    'YOLOv8n',
                    'YOLOv8s',
                    'YOLOv8m',
                    'YOLOv8l',
                    'YOLOv8x',
                ],
                value='YOLOv8n', label='Model'
            )
            conf = gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.05,
                             label='Confidence Threshold')
            btn = gr.Button()

    # Third row: one gallery showing the three plotted predictions.
    with gr.Row():
        with gr.Column():
            gallery = gr.Gallery(
                label='Predictions', show_label=True, elem_id='gallery',
                columns=[3], rows=[1], object_fit='contain', interactive=False
            )

    # Run all three models on click and show the results in the gallery.
    btn.click(run_models, [inp, model, conf], gallery)

# Start the local server only when this file is executed as a script.
if __name__ == '__main__':
    demo.launch(share=False)