chinmayc3 committed on
Commit
84a9b26
·
1 Parent(s): 89059a4

converted app to multipage to improve performance

Browse files
Files changed (5) hide show
  1. app.py +84 -455
  2. enums.py +8 -0
  3. pages/about.py +57 -0
  4. pages/scoreboard.py +312 -0
  5. utils.py +20 -0
app.py CHANGED
@@ -1,42 +1,23 @@
1
- import streamlit as st
 
 
2
  import os
3
- import re
4
  import tempfile
5
- from audio_recorder_streamlit import audio_recorder
6
- import numpy as np
7
  import time
8
- import requests
9
- import io
10
- import base64
11
- import random
12
- import librosa
13
  import fsspec
 
 
14
  import pandas as pd
15
- import plotly.express as px
16
- import plotly.graph_objects as go
17
- import boto3
18
- import json
19
- from plotly.subplots import make_subplots
20
- from logger import logger
21
 
22
- fs = fsspec.filesystem(
23
- 's3',
24
- key=os.getenv("AWS_ACCESS_KEY"),
25
- secret=os.getenv("AWS_SECRET_KEY")
26
- )
27
-
28
- s3_client = boto3.client(
29
- 's3',
30
- aws_access_key_id=os.getenv("AWS_ACCESS_KEY"),
31
- aws_secret_access_key=os.getenv("AWS_SECRET_KEY")
32
- )
33
-
34
- SAVE_PATH = f"s3://{os.getenv('AWS_BUCKET_NAME')}/{os.getenv('RESULTS_KEY')}"
35
- ELO_JSON_PATH = f"s3://{os.getenv('AWS_BUCKET_NAME')}/{os.getenv('ELO_JSON_PATH')}"
36
- ELO_CSV_PATH = f"s3://{os.getenv('AWS_BUCKET_NAME')}/{os.getenv('ELO_CSV_KEY')}"
37
- EMAIL_PATH = f"s3://{os.getenv('AWS_BUCKET_NAME')}/{os.getenv('EMAILS_KEY')}"
38
- TEMP_DIR = f"s3://{os.getenv('AWS_BUCKET_NAME')}/{os.getenv('AUDIOS_KEY')}"
39
- CREATE_TASK_URL = os.getenv("CREATE_TASK_URL")
40
 
41
  def create_files():
42
  if not fs.exists(SAVE_PATH):
@@ -314,324 +295,6 @@ def on_random_click():
314
 
315
  result_writer = ResultWriter(SAVE_PATH)
316
 
317
- def validate_email(email):
318
- pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
319
- return re.match(pattern, email) is not None
320
-
321
- def get_model_abbreviation(model_name):
322
- abbrev_map = {
323
- 'Ori Apex': 'Ori Apex',
324
- 'Ori Apex XT': 'Ori Apex XT',
325
- 'deepgram': 'DG',
326
- 'Ori Swift': 'Ori Swift',
327
- 'Ori Prime': 'Ori Prime',
328
- 'azure' : 'Azure'
329
- }
330
- return abbrev_map.get(model_name, model_name)
331
-
332
-
333
- def calculate_metrics(df):
334
- models = ['Ori Apex', 'Ori Apex XT', 'deepgram', 'Ori Swift', 'Ori Prime', 'azure']
335
- metrics = {}
336
-
337
- for model in models:
338
- appearances = df[f'{model}_appearance'].sum()
339
- wins = df[f'{model}_score'].sum()
340
- durations = df[df[f'{model}_appearance'] == 1][f'{model}_duration']
341
-
342
- if appearances > 0:
343
- win_rate = (wins / appearances) * 100
344
- avg_duration = durations.mean()
345
- duration_std = durations.std()
346
- else:
347
- win_rate = 0
348
- avg_duration = 0
349
- duration_std = 0
350
-
351
- metrics[model] = {
352
- 'appearances': appearances,
353
- 'wins': wins,
354
- 'win_rate': win_rate,
355
- 'avg_response_time': avg_duration,
356
- 'response_time_std': duration_std
357
- }
358
-
359
- return metrics
360
-
361
- def create_win_rate_chart(metrics):
362
- models = list(metrics.keys())
363
- win_rates = [metrics[model]['win_rate'] for model in models]
364
-
365
- fig = go.Figure(data=[
366
- go.Bar(
367
- x=[get_model_abbreviation(model) for model in models],
368
- y=win_rates,
369
- text=[f'{rate:.1f}%' for rate in win_rates],
370
- textposition='auto',
371
- hovertext=models
372
- )
373
- ])
374
-
375
- fig.update_layout(
376
- title='Win Rate by Model',
377
- xaxis_title='Model',
378
- yaxis_title='Win Rate (%)',
379
- yaxis_range=[0, 100]
380
- )
381
-
382
- return fig
383
-
384
- def create_appearance_chart(metrics):
385
- models = list(metrics.keys())
386
- appearances = [metrics[model]['appearances'] for model in models]
387
-
388
- fig = px.pie(
389
- values=appearances,
390
- names=[get_model_abbreviation(model) for model in models],
391
- title='Model Appearances Distribution',
392
- hover_data=[models]
393
- )
394
-
395
- return fig
396
-
397
- def create_head_to_head_matrix(df):
398
- models = ['Ori Apex', 'Ori Apex XT', 'deepgram', 'Ori Swift', 'Ori Prime', 'azure']
399
- matrix = np.zeros((len(models), len(models)))
400
-
401
- for i, model1 in enumerate(models):
402
- for j, model2 in enumerate(models):
403
- if i != j:
404
- matches = df[
405
- (df[f'{model1}_appearance'] == 1) &
406
- (df[f'{model2}_appearance'] == 1)
407
- ]
408
- if len(matches) > 0:
409
- win_rate = (matches[f'{model1}_score'].sum() / len(matches)) * 100
410
- matrix[i][j] = win_rate
411
-
412
- fig = go.Figure(data=go.Heatmap(
413
- z=matrix,
414
- x=[get_model_abbreviation(model) for model in models],
415
- y=[get_model_abbreviation(model) for model in models],
416
- text=[[f'{val:.1f}%' if val > 0 else '' for val in row] for row in matrix],
417
- texttemplate='%{text}',
418
- colorscale='RdYlBu',
419
- zmin=0,
420
- zmax=100
421
- ))
422
-
423
- fig.update_layout(
424
- title='Head-to-Head Win Rates',
425
- xaxis_title='Opponent Model',
426
- yaxis_title='Model'
427
- )
428
-
429
- return fig
430
-
431
- def create_elo_chart(df):
432
- fig = make_subplots(rows=1, cols=1,
433
- subplot_titles=('ELO Rating Progression'),
434
- row_heights=[0.7])
435
-
436
- for column in df.columns:
437
- fig.add_trace(
438
- go.Scatter(
439
- x=list(range(len(df))),
440
- y=df[column],
441
- name=column,
442
- mode='lines+markers'
443
- ),
444
- row=1, col=1
445
- )
446
-
447
- fig.update_layout(
448
- title='Model ELO Ratings Analysis',
449
- showlegend=True,
450
- hovermode='x unified'
451
- )
452
-
453
- fig.update_xaxes(title_text='Match Number', row=1, col=1)
454
- fig.update_xaxes(title_text='Models', row=2, col=1)
455
-
456
- return fig
457
-
458
- def create_metric_container(label, value, full_name=None):
459
- container = st.container()
460
- with container:
461
- st.markdown(f"**{label}**")
462
- if full_name:
463
- st.markdown(f"<h3 style='margin-top: 0;'>{value}</h3>", unsafe_allow_html=True)
464
- st.caption(f"Full name: {full_name}")
465
- else:
466
- st.markdown(f"<h3 style='margin-top: 0;'>{value}</h3>", unsafe_allow_html=True)
467
-
468
- def on_refresh_click():
469
- with st.spinner("Refreshing data... please wait"):
470
- with fs.open(SAVE_PATH, 'rb') as f:
471
- st.session_state.df = pd.read_csv(f)
472
-
473
- try:
474
- with fs.open(ELO_JSON_PATH,'r') as f:
475
- st.session_state.elo_json = json.load(f)
476
- except Exception as e:
477
- logger.error("Error while reading elo json file %s",e)
478
- st.session_state.elo_json = None
479
-
480
- try:
481
- with fs.open(ELO_CSV_PATH,'rb') as f:
482
- st.session_state.elo_df = pd.read_csv(f)
483
- except Exception as e:
484
- logger.error("Error while reading elo csv file %s",e)
485
- st.session_state.elo_df = None
486
-
487
- def dashboard():
488
- st.title('Model Arena Scoreboard')
489
-
490
- if "df" not in st.session_state:
491
- with fs.open(SAVE_PATH, 'rb') as f:
492
- st.session_state.df = pd.read_csv(f)
493
- if "elo_json" not in st.session_state:
494
- with fs.open(ELO_JSON_PATH,'r') as f:
495
- elo_json = json.load(f)
496
- st.session_state.elo_json = elo_json
497
- if "elo_df" not in st.session_state:
498
- with fs.open(ELO_CSV_PATH,'rb') as f:
499
- elo_df = pd.read_csv(f)
500
- st.session_state.elo_df = elo_df
501
-
502
- st.button("Refresh",on_click=on_refresh_click)
503
-
504
- if len(st.session_state.df) != 0:
505
- metrics = calculate_metrics(st.session_state.df)
506
-
507
- MODEL_DESCRIPTIONS = {
508
- "Ori Prime": "Foundational, large, and stable.",
509
- "Ori Swift": "Lighter and faster than Ori Prime.",
510
- "Ori Apex": "The top-performing model, fast and stable.",
511
- "Ori Apex XT": "Enhanced with more training, though slightly less stable than Ori Apex.",
512
- "DG" : "Deepgram Nova-2 API",
513
- "Azure" : "Azure Speech Services API"
514
- }
515
-
516
- st.header('Model Descriptions')
517
-
518
- cols = st.columns(2)
519
- for idx, (model, description) in enumerate(MODEL_DESCRIPTIONS.items()):
520
- with cols[idx % 2]:
521
- st.markdown(f"""
522
- <div style='padding: 1rem; border: 1px solid #e1e4e8; border-radius: 6px; margin-bottom: 1rem;'>
523
- <h3 style='margin: 0; margin-bottom: 0.5rem;'>{model}</h3>
524
- <p style='margin: 0; color: #6e7681;'>{description}</p>
525
- </div>
526
- """, unsafe_allow_html=True)
527
-
528
- st.header('Overall Performance')
529
-
530
- col1, col2, col3= st.columns(3)
531
-
532
- with col1:
533
- create_metric_container("Total Matches", len(st.session_state.df))
534
-
535
- # best_model = max(metrics.items(), key=lambda x: x[1]['win_rate'])[0]
536
- best_model = max(st.session_state.elo_json.items(), key=lambda x: x[1])[0] if st.session_state.elo_json else max(metrics.items(), key=lambda x: x[1]['win_rate'])[0]
537
- with col2:
538
- create_metric_container(
539
- "Best Model",
540
- get_model_abbreviation(best_model),
541
- full_name=best_model
542
- )
543
-
544
- most_appearances = max(metrics.items(), key=lambda x: x[1]['appearances'])[0]
545
- with col3:
546
- create_metric_container(
547
- "Most Used",
548
- get_model_abbreviation(most_appearances),
549
- full_name=most_appearances
550
- )
551
-
552
- metrics_df = pd.DataFrame.from_dict(metrics, orient='index')
553
- metrics_df['win_rate'] = metrics_df['win_rate'].round(2)
554
- metrics_df.drop(["avg_response_time","response_time_std"],axis=1,inplace=True)
555
- metrics_df.index = [get_model_abbreviation(model) for model in metrics_df.index]
556
- st.dataframe(metrics_df,use_container_width=True)
557
-
558
- st.header('Win Rates')
559
- win_rate_chart = create_win_rate_chart(metrics)
560
- st.plotly_chart(win_rate_chart, use_container_width=True)
561
-
562
- st.header('Appearance Distribution')
563
- appearance_chart = create_appearance_chart(metrics)
564
- st.plotly_chart(appearance_chart, use_container_width=True)
565
-
566
- if st.session_state.elo_json is not None and st.session_state.elo_df is not None:
567
- st.header('Elo Ratings')
568
- st.dataframe(pd.DataFrame(st.session_state.elo_json,index=[0]),use_container_width=True)
569
- elo_progression_chart = create_elo_chart(st.session_state.elo_df)
570
- st.plotly_chart(elo_progression_chart, use_container_width=True)
571
-
572
- st.header('Head-to-Head Analysis')
573
- matrix_chart = create_head_to_head_matrix(st.session_state.df)
574
- st.plotly_chart(matrix_chart, use_container_width=True)
575
-
576
- st.header('Full Dataframe')
577
- st.dataframe(st.session_state.df.drop(['path','Ori Apex_duration', 'Ori Apex XT_duration', 'deepgram_duration', 'Ori Swift_duration', 'Ori Prime_duration','azure_duration','email'],axis=1),use_container_width=True)
578
- else:
579
- st.write("No Data to show")
580
-
581
- def about():
582
- st.title("About")
583
-
584
- st.markdown(
585
- """
586
- # Ori Speech-To-Text Arena
587
- """
588
- )
589
-
590
- st.markdown(
591
- """## Arena
592
- """
593
- )
594
-
595
- st.markdown(
596
- """
597
- * The Arena allows a user to record their audios, in which speech will be recognized by two randomly selected models. After listening to the audio, and evaluating the output from both the models, the user can vote on which transcription they prefer. Due to the risks of human bias and abuse, model names are revealed only after a vote is submitted."""
598
- )
599
-
600
- st.markdown(
601
- "## Scoreboard"
602
- )
603
-
604
- st.markdown(
605
- """ * The Scoreboard shows the performance of the models in the Arena. The user can see the overall performance of the models, the model with the highest win rate, and the model with the most appearances. The user can also see the win rates of each model, as well as the appearance distribution of each model."""
606
- )
607
-
608
- st.markdown(
609
- "## Contact Us"
610
- )
611
-
612
- st.markdown(
613
- "To inquire about our speech-to-text models and APIs, you can submit your email using the form below."
614
- )
615
-
616
- with st.form("login_form"):
617
- st.subheader("Please Enter you Email")
618
-
619
- email = st.text_input("Email")
620
-
621
- submit_button = st.form_submit_button("Submit")
622
-
623
- if submit_button:
624
- if not email:
625
- st.error("Please fill in all fields")
626
- else:
627
- if not validate_email(email):
628
- st.error("Please enter a valid email address")
629
- else:
630
- st.session_state.logged_in = True
631
- st.session_state.user_email = email
632
- write_email(st.session_state.user_email)
633
- st.success("Thanks for submitting your email, our team will be in touch with you shortly!")
634
-
635
  def main():
636
 
637
  st.title("⚔️ Ori Speech-To-Text Arena ⚔️")
@@ -659,123 +322,89 @@ def main():
659
  if "user_email" not in st.session_state:
660
  st.session_state.user_email = ""
661
 
662
- if 'logged_in' not in st.session_state:
663
- st.session_state.logged_in = False
664
-
665
- arena, scoreboard,about_tab = st.tabs(["Arena", "Scoreboard","About"])
666
-
667
- with arena:
668
- INSTR = """
669
- ## Instructions:
670
- * Record audio to recognise speech (or press 🎲 for random Audio).
671
- * Click on transcribe audio button to commence the transcription process.
672
- * Read the two options one after the other while listening to the audio.
673
- * Vote on which transcript you prefer.
674
- * Note:
675
- * Model names are revealed after the vote is cast.
676
- * Currently only Indian Hindi language is supported, and
677
- the results will be in Hinglish (Hindi in Latin script)
678
- * Random audios are only in hindi
679
- * It may take up to 30 seconds for speech recognition in some cases.
680
- """.strip()
681
-
682
- st.markdown(INSTR)
683
-
684
- col1, col2 = st.columns([1, 1])
685
-
686
- with col1:
687
- st.markdown("### Record Audio")
688
- with st.container():
689
- audio_bytes = audio_recorder(
690
- text="🎙️ Click to Record",
691
- pause_threshold=3,
692
- icon_size="2x",
693
- key="audio_recorder",
694
- sample_rate=16_000
695
- )
696
- if audio_bytes and audio_bytes != st.session_state.get('last_recorded_audio'):
697
- reset_state()
698
- st.session_state.last_recorded_audio = audio_bytes
699
- st.session_state.audio = {"data":audio_bytes,"format":"audio/wav"}
700
- st.session_state.current_audio_type = "recorded"
701
- st.session_state.has_audio = True
702
- with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
703
- tmp_file.write(audio_bytes)
704
- os.makedirs(TEMP_DIR, exist_ok=True)
705
- s3_client.put_object(Bucket=os.getenv('AWS_BUCKET_NAME'), Key=f"{os.getenv('AUDIOS_KEY')}/{tmp_file.name.split('/')[-1]}", Body=audio_bytes)
706
- st.session_state.audio_path = tmp_file.name
707
- st.session_state.option_selected = None
708
-
709
- with col2:
710
- st.markdown("### Random Audio Example")
711
- with st.container():
712
- st.button("🎲 Random Audio",on_click=on_random_click)
713
-
714
- if st.session_state.has_audio:
715
- st.audio(**st.session_state.audio)
716
 
 
 
717
 
718
- with st.container():
719
- st.button("📝 Transcribe Audio",on_click=on_click_transcribe,use_container_width=True)
720
 
721
- text_containers = st.columns([1, 1])
722
- name_containers = st.columns([1, 1])
723
 
724
- with text_containers[0]:
725
- st.text_area("Option 1", value=st.session_state.option_1, height=300)
726
 
727
- with text_containers[1]:
728
- st.text_area("Option 2", value=st.session_state.option_2, height=300)
729
 
730
- with name_containers[0]:
731
- if st.session_state.option_1_model_name_state:
732
- st.markdown(f"<div style='text-align: center'>{st.session_state.option_1_model_name_state}</div>", unsafe_allow_html=True)
733
 
734
- with name_containers[1]:
735
- if st.session_state.option_2_model_name_state:
736
- st.markdown(f"<div style='text-align: center'>{st.session_state.option_2_model_name_state}</div>", unsafe_allow_html=True)
737
 
738
- c1, c2, c3, c4 = st.columns(4)
 
 
739
 
740
- with c1:
741
- st.button("Prefer Option 1",on_click=on_option_1_click)
742
 
743
- with c2:
744
- st.button("Prefer Option 2",on_click=on_option_2_click)
745
 
746
- with c3:
747
- st.button("Prefer Both",on_click=on_option_both_click)
748
 
749
- with c4:
750
- st.button("Prefer None",on_click=on_option_none_click)
751
 
752
- with scoreboard:
753
- if st.session_state.logged_in or os.getenv("IS_TEST"):
754
- dashboard()
755
- else:
756
- with st.form("contact_us_form"):
757
- st.subheader("Please Enter you Email")
758
-
759
- email = st.text_input("Email")
760
-
761
- submit_button = st.form_submit_button("Submit")
762
-
763
- if submit_button:
764
- if not email:
765
- st.error("Please fill in all fields")
766
- else:
767
- if not validate_email(email):
768
- st.error("Please enter a valid email address")
769
- else:
770
- st.session_state.logged_in = True
771
- st.session_state.user_email = email
772
- write_email(st.session_state.user_email)
773
- st.success("Thanks for submitting your email")
774
- if st.session_state.logged_in:
775
- dashboard()
776
-
777
- with about_tab:
778
- about()
779
 
780
  create_files()
781
  main()
 
1
+ import base64
2
+ import io
3
+ import json
4
  import os
5
+ import random
6
  import tempfile
 
 
7
  import time
8
+
9
+ import boto3
 
 
 
10
  import fsspec
11
+ import librosa
12
+ import numpy as np
13
  import pandas as pd
14
+ import requests
15
+ import streamlit as st
16
+ from audio_recorder_streamlit import audio_recorder
 
 
 
17
 
18
+ from logger import logger
19
+ from utils import fs,s3_client
20
+ from enums import SAVE_PATH, ELO_JSON_PATH, ELO_CSV_PATH, EMAIL_PATH, TEMP_DIR, CREATE_TASK_URL
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  def create_files():
23
  if not fs.exists(SAVE_PATH):
 
295
 
296
  result_writer = ResultWriter(SAVE_PATH)
297
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
  def main():
299
 
300
  st.title("⚔️ Ori Speech-To-Text Arena ⚔️")
 
322
  if "user_email" not in st.session_state:
323
  st.session_state.user_email = ""
324
 
325
+ INSTR = """
326
+ ## Instructions:
327
+ * Record audio to recognise speech (or press 🎲 for random Audio).
328
+ * Click on transcribe audio button to commence the transcription process.
329
+ * Read the two options one after the other while listening to the audio.
330
+ * Vote on which transcript you prefer.
331
+ * Note:
332
+ * Model names are revealed after the vote is cast.
333
+ * Currently only Indian Hindi language is supported, and
334
+ the results will be in Hinglish (Hindi in Latin script)
335
+ * Random audios are only in hindi
336
+ * It may take up to 30 seconds for speech recognition in some cases.
337
+ """.strip()
338
+
339
+ st.markdown(INSTR)
340
+
341
+ col1, col2 = st.columns([1, 1])
342
+
343
+ with col1:
344
+ st.markdown("### Record Audio")
345
+ with st.container():
346
+ audio_bytes = audio_recorder(
347
+ text="🎙️ Click to Record",
348
+ pause_threshold=3,
349
+ icon_size="2x",
350
+ key="audio_recorder",
351
+ sample_rate=16_000
352
+ )
353
+ if audio_bytes and audio_bytes != st.session_state.get('last_recorded_audio'):
354
+ reset_state()
355
+ st.session_state.last_recorded_audio = audio_bytes
356
+ st.session_state.audio = {"data":audio_bytes,"format":"audio/wav"}
357
+ st.session_state.current_audio_type = "recorded"
358
+ st.session_state.has_audio = True
359
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
360
+ tmp_file.write(audio_bytes)
361
+ os.makedirs(TEMP_DIR, exist_ok=True)
362
+ # s3_client.put_object(Bucket=os.getenv('AWS_BUCKET_NAME'), Key=f"{os.getenv('AUDIOS_KEY')}/{tmp_file.name.split('/')[-1]}", Body=audio_bytes)
363
+ st.session_state.audio_path = tmp_file.name
364
+ st.session_state.option_selected = None
365
+
366
+ with col2:
367
+ st.markdown("### Random Audio Example")
368
+ with st.container():
369
+ st.button("🎲 Random Audio",on_click=on_random_click)
 
 
 
 
 
 
 
 
 
370
 
371
+ if st.session_state.has_audio:
372
+ st.audio(**st.session_state.audio)
373
 
 
 
374
 
375
+ with st.container():
376
+ st.button("📝 Transcribe Audio",on_click=on_click_transcribe,use_container_width=True)
377
 
378
+ text_containers = st.columns([1, 1])
379
+ name_containers = st.columns([1, 1])
380
 
381
+ with text_containers[0]:
382
+ st.text_area("Option 1", value=st.session_state.option_1, height=300)
383
 
384
+ with text_containers[1]:
385
+ st.text_area("Option 2", value=st.session_state.option_2, height=300)
 
386
 
387
+ with name_containers[0]:
388
+ if st.session_state.option_1_model_name_state:
389
+ st.markdown(f"<div style='text-align: center'>{st.session_state.option_1_model_name_state}</div>", unsafe_allow_html=True)
390
 
391
+ with name_containers[1]:
392
+ if st.session_state.option_2_model_name_state:
393
+ st.markdown(f"<div style='text-align: center'>{st.session_state.option_2_model_name_state}</div>", unsafe_allow_html=True)
394
 
395
+ c1, c2, c3, c4 = st.columns(4)
 
396
 
397
+ with c1:
398
+ st.button("Prefer Option 1",on_click=on_option_1_click)
399
 
400
+ with c2:
401
+ st.button("Prefer Option 2",on_click=on_option_2_click)
402
 
403
+ with c3:
404
+ st.button("Prefer Both",on_click=on_option_both_click)
405
 
406
+ with c4:
407
+ st.button("Prefer None",on_click=on_option_none_click)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
408
 
409
  create_files()
410
  main()
enums.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
# Every storage location lives in the same bucket; only the object key differs.
_BUCKET = os.getenv('AWS_BUCKET_NAME')


def _s3_uri(key_env_var):
    """Return the ``s3://bucket/key`` URI whose key is read from ``key_env_var``."""
    return "s3://{}/{}".format(_BUCKET, os.getenv(key_env_var))


SAVE_PATH = _s3_uri('RESULTS_KEY')
ELO_JSON_PATH = _s3_uri('ELO_JSON_PATH')
ELO_CSV_PATH = _s3_uri('ELO_CSV_KEY')
EMAIL_PATH = _s3_uri('EMAILS_KEY')
TEMP_DIR = _s3_uri('AUDIOS_KEY')
# Not an S3 location: taken verbatim from the environment (None when unset).
CREATE_TASK_URL = os.getenv("CREATE_TASK_URL")
pages/about.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from pages.scoreboard import validate_email
3
+ from pages.scoreboard import write_email
4
+
5
+
6
# "About" page: static description of the Arena and Scoreboard followed by a
# contact form that records the visitor's email address.
st.title("About")

st.markdown(
    """
# Ori Speech-To-Text Arena
"""
)

st.markdown(
    """## Arena
"""
)

st.markdown(
    """
* The Arena allows a user to record their audios, in which speech will be recognized by two randomly selected models. After listening to the audio, and evaluating the output from both the models, the user can vote on which transcription they prefer. Due to the risks of human bias and abuse, model names are revealed only after a vote is submitted."""
)

st.markdown("## Scoreboard")

st.markdown(
    """ * The Scoreboard shows the performance of the models in the Arena. The user can see the overall performance of the models, the model with the highest win rate, and the model with the most appearances. The user can also see the win rates of each model, as well as the appearance distribution of each model."""
)

st.markdown("## Contact Us")

st.markdown(
    "To inquire about our speech-to-text models and APIs, you can submit your email using the form below."
)

with st.form("login_form"):
    st.subheader("Please Enter your Email")

    email = st.text_input("Email")

    submit_button = st.form_submit_button("Submit")

    if submit_button:
        if not email:
            st.error("Please fill in all fields")
        elif not validate_email(email):
            st.error("Please enter a valid email address")
        else:
            # Persist the address before flagging the session, so a storage
            # failure does not leave the session marked as logged in without
            # the email having been recorded.
            write_email(email)
            st.session_state.user_email = email
            st.session_state.logged_in = True
            st.success("Thanks for submitting your email, our team will be in touch with you shortly!")
pages/scoreboard.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import numpy as np
4
+ import pandas as pd
5
+ from logger import logger
6
+ import plotly.express as px
7
+ import plotly.graph_objects as go
8
+ from plotly.subplots import make_subplots
9
+ import json
10
+ from utils import fs,validate_email
11
+ from enums import SAVE_PATH, ELO_JSON_PATH, ELO_CSV_PATH, EMAIL_PATH
12
+
13
+
14
def write_email(email):
    """Append ``email`` on its own line to the file at ``EMAIL_PATH``.

    The file is rewritten in full: its current content is read (an absent
    file counts as empty), the address plus a newline is appended, and the
    combined text is written back as UTF-8.
    """
    previous = ''
    if fs.exists(EMAIL_PATH):
        with fs.open(EMAIL_PATH, 'rb') as reader:
            previous = reader.read().decode('utf-8')

    # NOTE(review): read-then-write is not atomic; two simultaneous
    # submissions could overwrite each other - confirm whether that matters.
    payload = (previous + email + '\n').encode('utf-8')
    with fs.open(EMAIL_PATH, 'wb') as writer:
        writer.write(payload)
25
+
26
def get_model_abbreviation(model_name):
    """Return the short display label for ``model_name``.

    Only ``'deepgram'`` and ``'azure'`` have a label that differs from the
    name itself; any other name, known or not, is returned unchanged.
    """
    relabelled = {'deepgram': 'DG', 'azure': 'Azure'}
    return relabelled.get(model_name, model_name)
36
+
37
+
38
def calculate_metrics(df, models=None):
    """Summarise appearances, wins, win rate and response time per model.

    Args:
        df: Results table with ``<model>_appearance``, ``<model>_score`` and
            ``<model>_duration`` columns for each model.
        models: Model names to summarise. Defaults to the six arena models.

    Returns:
        Dict mapping each model name to a dict with the keys
        ``appearances``, ``wins``, ``win_rate`` (percent),
        ``avg_response_time`` and ``response_time_std``.

    A model whose columns are absent from ``df`` is reported with zero
    appearances and wins rather than raising ``KeyError``. Response-time
    statistics that cannot be computed (for example the standard deviation
    of a single sample) are reported as 0, matching the fallback used when
    a model never appeared.
    """
    if models is None:
        models = ['Ori Apex', 'Ori Apex XT', 'deepgram', 'Ori Swift', 'Ori Prime', 'azure']

    metrics = {}

    for model in models:
        appearance_col = f'{model}_appearance'
        score_col = f'{model}_score'
        duration_col = f'{model}_duration'

        appearances = df[appearance_col].sum() if appearance_col in df.columns else 0
        wins = df[score_col].sum() if score_col in df.columns else 0

        win_rate = 0
        avg_duration = 0
        duration_std = 0

        if appearances > 0:
            win_rate = (wins / appearances) * 100
            if duration_col in df.columns:
                # Only rows in which the model actually took part count
                # towards its response-time statistics.
                durations = df.loc[df[appearance_col] == 1, duration_col]
                avg_duration = durations.mean()
                duration_std = durations.std()
                # mean()/std() yield NaN when there is too little data.
                if pd.isna(avg_duration):
                    avg_duration = 0
                if pd.isna(duration_std):
                    duration_std = 0

        metrics[model] = {
            'appearances': appearances,
            'wins': wins,
            'win_rate': win_rate,
            'avg_response_time': avg_duration,
            'response_time_std': duration_std
        }

    return metrics
65
+
66
def create_win_rate_chart(metrics):
    """Build a bar chart of each model's win rate.

    ``metrics`` maps a model name to a dict holding a ``'win_rate'`` value.
    Bars are labelled with the abbreviated model name and annotated with the
    rate to one decimal place; the full name is attached as hover text.
    """
    model_names = list(metrics)
    rates = [metrics[name]['win_rate'] for name in model_names]

    bars = go.Bar(
        x=[get_model_abbreviation(name) for name in model_names],
        y=rates,
        text=[f'{rate:.1f}%' for rate in rates],
        textposition='auto',
        hovertext=model_names,
    )

    fig = go.Figure(data=[bars])
    fig.update_layout(
        title='Win Rate by Model',
        xaxis_title='Model',
        yaxis_title='Win Rate (%)',
        yaxis_range=[0, 100],
    )
    return fig
88
+
89
def create_appearance_chart(metrics):
    """Build a pie chart of how often each model appeared.

    ``metrics`` maps a model name to a dict holding an ``'appearances'``
    value. Slices are named with the abbreviated model name, and the full
    names are passed along as hover data.
    """
    model_names = list(metrics)
    counts = [metrics[name]['appearances'] for name in model_names]
    labels = [get_model_abbreviation(name) for name in model_names]

    return px.pie(
        values=counts,
        names=labels,
        title='Model Appearances Distribution',
        hover_data=[model_names],
    )
101
+
102
def create_head_to_head_matrix(df):
    """Build a heatmap of pairwise win rates between the arena models.

    Cell (row, column) is the percentage of matches featuring both models
    in which the row model scored. The diagonal and pairs that never met
    stay at 0, and cells that are not positive are left without a label.
    """
    models = ['Ori Apex', 'Ori Apex XT', 'deepgram', 'Ori Swift', 'Ori Prime', 'azure']
    size = len(models)
    matrix = np.zeros((size, size))

    for row, contender in enumerate(models):
        contender_played = df[f'{contender}_appearance'] == 1
        for col, opponent in enumerate(models):
            if row == col:
                continue
            shared = df[contender_played & (df[f'{opponent}_appearance'] == 1)]
            if len(shared) == 0:
                continue
            matrix[row][col] = (shared[f'{contender}_score'].sum() / len(shared)) * 100

    cell_labels = []
    for values in matrix:
        cell_labels.append([f'{val:.1f}%' if val > 0 else '' for val in values])

    heatmap = go.Heatmap(
        z=matrix,
        x=[get_model_abbreviation(model) for model in models],
        y=[get_model_abbreviation(model) for model in models],
        text=cell_labels,
        texttemplate='%{text}',
        colorscale='RdYlBu',
        zmin=0,
        zmax=100,
    )

    fig = go.Figure(data=heatmap)
    fig.update_layout(
        title='Head-to-Head Win Rates',
        xaxis_title='Opponent Model',
        yaxis_title='Model',
    )
    return fig
135
+
136
def create_elo_chart(df):
    """Build a line chart of ELO rating progression, one trace per column.

    Args:
        df: Table in which every column is plotted as one trace (named after
            the column) and every row is one point along the x axis.

    Returns:
        A single-subplot figure with the row position on the x axis.
    """
    fig = make_subplots(rows=1, cols=1,
                        # Trailing comma is required: without it the
                        # parentheses hold a plain string rather than a
                        # one-element tuple of titles.
                        subplot_titles=('ELO Rating Progression',),
                        row_heights=[0.7])

    match_numbers = list(range(len(df)))
    for column in df.columns:
        fig.add_trace(
            go.Scatter(
                x=match_numbers,
                y=df[column],
                name=column,
                mode='lines+markers'
            ),
            row=1, col=1
        )

    fig.update_layout(
        title='Model ELO Ratings Analysis',
        showlegend=True,
        hovermode='x unified'
    )

    # The grid has a single row, so only that row's x axis is titled.
    fig.update_xaxes(title_text='Match Number', row=1, col=1)

    return fig
162
+
163
def create_metric_container(label, value, full_name=None):
    """Render one headline metric: a bold label above a large value.

    When ``full_name`` is truthy, a "Full name: ..." caption is shown
    underneath the value.
    """
    with st.container():
        st.markdown(f"**{label}**")
        # The value is rendered as raw HTML so it can be styled as a heading.
        st.markdown(f"<h3 style='margin-top: 0;'>{value}</h3>", unsafe_allow_html=True)
        if full_name:
            st.caption(f"Full name: {full_name}")
172
+
173
def on_refresh_click():
    """Reload the results table and the ELO data into the session state.

    The results table is read unconditionally. The ELO JSON and ELO CSV are
    each optional: if reading one fails, the error is logged and the
    corresponding session entry is set to ``None``.
    """
    with st.spinner("Refreshing data... please wait"):
        with fs.open(SAVE_PATH, 'rb') as results_file:
            st.session_state.df = pd.read_csv(results_file)

        try:
            with fs.open(ELO_JSON_PATH,'r') as ratings_file:
                ratings = json.load(ratings_file)
        except Exception as err:
            logger.error("Error while reading elo json file %s",err)
            ratings = None
        st.session_state.elo_json = ratings

        try:
            with fs.open(ELO_CSV_PATH,'rb') as history_file:
                history = pd.read_csv(history_file)
        except Exception as err:
            logger.error("Error while reading elo csv file %s",err)
            history = None
        st.session_state.elo_df = history
191
+
192
def _read_optional(path, mode, reader, description):
    """Open ``path`` and parse it with ``reader``; log and return None on failure."""
    try:
        with fs.open(path, mode) as f:
            return reader(f)
    except Exception as e:
        logger.error("Error while reading %s %s", description, e)
        return None


def dashboard():
    """Render the scoreboard page.

    On first use the results table and the ELO data are loaded into the
    session state. The results table is required. The ELO JSON and ELO CSV
    are optional: a failed read is logged and stored as ``None``, the same
    way the "Refresh" callback handles it, and the ELO section is then
    simply not shown.

    When the results table is empty only a "No Data to show" message is
    rendered; otherwise the page shows model descriptions, headline metrics,
    win-rate and appearance charts, ELO ratings (if available), the
    head-to-head matrix and the raw results table.
    """
    st.title('Model Arena Scoreboard')

    if "df" not in st.session_state:
        with fs.open(SAVE_PATH, 'rb') as f:
            st.session_state.df = pd.read_csv(f)
    if "elo_json" not in st.session_state:
        st.session_state.elo_json = _read_optional(
            ELO_JSON_PATH, 'r', json.load, "elo json file")
    if "elo_df" not in st.session_state:
        st.session_state.elo_df = _read_optional(
            ELO_CSV_PATH, 'rb', pd.read_csv, "elo csv file")

    st.button("Refresh",on_click=on_refresh_click)

    if len(st.session_state.df) == 0:
        st.write("No Data to show")
        return

    metrics = calculate_metrics(st.session_state.df)

    MODEL_DESCRIPTIONS = {
        "Ori Prime": "Foundational, large, and stable.",
        "Ori Swift": "Lighter and faster than Ori Prime.",
        "Ori Apex": "The top-performing model, fast and stable.",
        "Ori Apex XT": "Enhanced with more training, though slightly less stable than Ori Apex.",
        "DG" : "Deepgram Nova-2 API",
        "Azure" : "Azure Speech Services API"
    }

    st.header('Model Descriptions')

    # Description cards alternate between two columns.
    cols = st.columns(2)
    for idx, (model, description) in enumerate(MODEL_DESCRIPTIONS.items()):
        with cols[idx % 2]:
            st.markdown(f"""
                <div style='padding: 1rem; border: 1px solid #e1e4e8; border-radius: 6px; margin-bottom: 1rem;'>
                    <h3 style='margin: 0; margin-bottom: 0.5rem;'>{model}</h3>
                    <p style='margin: 0; color: #6e7681;'>{description}</p>
                </div>
                """, unsafe_allow_html=True)

    st.header('Overall Performance')

    col1, col2, col3= st.columns(3)

    with col1:
        create_metric_container("Total Matches", len(st.session_state.df))

    # Prefer the highest ELO rating; fall back to the highest win rate when
    # no ELO data is available.
    if st.session_state.elo_json:
        best_model = max(st.session_state.elo_json.items(), key=lambda x: x[1])[0]
    else:
        best_model = max(metrics.items(), key=lambda x: x[1]['win_rate'])[0]
    with col2:
        create_metric_container(
            "Best Model",
            get_model_abbreviation(best_model),
            full_name=best_model
        )

    most_appearances = max(metrics.items(), key=lambda x: x[1]['appearances'])[0]
    with col3:
        create_metric_container(
            "Most Used",
            get_model_abbreviation(most_appearances),
            full_name=most_appearances
        )

    metrics_df = pd.DataFrame.from_dict(metrics, orient='index')
    metrics_df['win_rate'] = metrics_df['win_rate'].round(2)
    metrics_df.drop(["avg_response_time","response_time_std"],axis=1,inplace=True)
    metrics_df.index = [get_model_abbreviation(model) for model in metrics_df.index]
    st.dataframe(metrics_df,use_container_width=True)

    st.header('Win Rates')
    win_rate_chart = create_win_rate_chart(metrics)
    st.plotly_chart(win_rate_chart, use_container_width=True)

    st.header('Appearance Distribution')
    appearance_chart = create_appearance_chart(metrics)
    st.plotly_chart(appearance_chart, use_container_width=True)

    if st.session_state.elo_json is not None and st.session_state.elo_df is not None:
        st.header('Elo Ratings')
        st.dataframe(pd.DataFrame(st.session_state.elo_json,index=[0]),use_container_width=True)
        elo_progression_chart = create_elo_chart(st.session_state.elo_df)
        st.plotly_chart(elo_progression_chart, use_container_width=True)

    st.header('Head-to-Head Analysis')
    matrix_chart = create_head_to_head_matrix(st.session_state.df)
    st.plotly_chart(matrix_chart, use_container_width=True)

    st.header('Full Dataframe')
    hidden_columns = ['path','Ori Apex_duration', 'Ori Apex XT_duration', 'deepgram_duration', 'Ori Swift_duration', 'Ori Prime_duration','azure_duration','email']
    # errors='ignore' keeps the page rendering when the results table lacks
    # one of the hidden columns; columns that are present are still dropped.
    st.dataframe(
        st.session_state.df.drop(hidden_columns, axis=1, errors='ignore'),
        use_container_width=True,
    )
285
+
286
# Page entry point. NOTE(review): the __main__ guard presumably keeps the page
# from rendering when pages/about.py imports helpers from this module -
# confirm against how Streamlit executes page scripts.
if __name__ == "__main__":
    if 'logged_in' not in st.session_state:
        st.session_state.logged_in = False

    if st.session_state.logged_in or os.getenv("IS_TEST"):
        dashboard()
    else:
        # Visitors must leave an email address before seeing the scoreboard.
        with st.form("contact_us_form"):
            st.subheader("Please Enter your Email")

            email = st.text_input("Email")

            submit_button = st.form_submit_button("Submit")

            if submit_button:
                if not email:
                    st.error("Please fill in all fields")
                elif not validate_email(email):
                    st.error("Please enter a valid email address")
                else:
                    # Persist the address before unlocking the scoreboard, so
                    # a storage failure cannot leave the session logged in
                    # without the email having been recorded.
                    write_email(email)
                    st.session_state.user_email = email
                    st.session_state.logged_in = True
                    st.success("Thanks for submitting your email")

        # Rendered outside the form: the dashboard contains a regular button.
        if st.session_state.logged_in:
            dashboard()
utils.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import fsspec
2
+ import boto3
3
+ import os
4
+ import re
5
+
6
# Both S3 handles authenticate with the same key pair from the environment.
_ACCESS_KEY = os.getenv("AWS_ACCESS_KEY")
_SECRET_KEY = os.getenv("AWS_SECRET_KEY")

# fsspec filesystem handle for S3.
fs = fsspec.filesystem('s3', key=_ACCESS_KEY, secret=_SECRET_KEY)

# boto3 S3 client.
s3_client = boto3.client(
    's3',
    aws_access_key_id=_ACCESS_KEY,
    aws_secret_access_key=_SECRET_KEY,
)
17
+
18
def validate_email(email):
    """Return True when ``email`` has the shape ``local@domain.tld``.

    The whole string must match: a trailing newline or any other extra
    character makes the address invalid. Values that are not strings are
    rejected instead of raising.
    """
    if not isinstance(email, str):
        return False
    pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
    # fullmatch (rather than match) is needed because ``$`` also matches just
    # before a trailing newline, which would let "a@b.co\n" through.
    return re.fullmatch(pattern, email) is not None