# RCT_generator / app.py
# admin
# sync
# f507537
# raw
# history blame
# No virus
# 2.62 kB
import os
import csv
import random
import shutil
import pandas as pd
import gradio as gr
# Output directory for the generated CSV file; inference() deletes and
# recreates it on every call.
DATA_DIR: str = "./data"
def list_to_csv(list_of_dicts: list, filename: str) -> str:
    """Write a list of dict rows to *filename* as UTF-8 CSV and return the path.

    The header row is taken from the keys of the first row, in insertion
    order; every row is written with ``csv.DictWriter``. An empty list
    produces an empty file (there are no keys to build a header from)
    instead of raising ``IndexError``.

    Args:
        list_of_dicts: Rows to write; each must be a mapping whose keys are
            a subset of the first row's keys.
        filename: Destination path; an existing file is overwritten.

    Returns:
        ``filename``, unchanged.
    """
    rows = list(list_of_dicts)
    # newline="" is required by the csv module so it controls line endings itself.
    with open(filename, "w", newline="", encoding="utf-8") as csvfile:
        if not rows:
            return filename
        writer = csv.DictWriter(csvfile, fieldnames=list(dict(rows[0])))
        writer.writeheader()
        writer.writerows(rows)
    return filename
def random_allocation(participants: int, ratio: list):
    """Randomly allocate participants ``1..participants`` to groups by weight.

    Group ``i + 1`` receives a share of participants proportional to
    ``ratio[i]``. The result is written to ``DATA_DIR/output.csv`` and also
    returned as a DataFrame; both are sorted by participant id and have the
    columns ``id`` and ``group`` (groups are numbered from 1).

    Args:
        participants: Number of participants; must be an integer >= 1.
        ratio: Group weights; all must be non-negative and their sum must be
            positive and finite.

    Returns:
        Tuple ``(csv_path, dataframe)``.

    Raises:
        ValueError: If ``participants`` < 1, or ``ratio`` is empty, contains a
            negative/NaN weight, or does not have a positive finite sum.
    """
    import math

    if participants < 1:
        raise ValueError("participants must be a positive integer")
    total = sum(ratio)
    if (
        not ratio
        or not all(r >= 0 for r in ratio)  # also rejects NaN weights
        or not (total > 0 and math.isfinite(total))
    ):
        raise ValueError(
            "ratio must be non-negative numbers with a positive, finite sum"
        )

    # Place each group boundary at the rounded *cumulative* share. Summing
    # per-group floors pushes the entire rounding remainder into the last
    # group; rounding the running total keeps every group within one
    # participant of its exact share.
    splits = [0]
    cumulative = 0.0
    for weight in ratio:
        cumulative += weight
        splits.append(round(cumulative / total * participants))
    # Float error could leave the final boundary off by one; pin it exactly.
    splits[-1] = participants

    shuffled_ids = list(range(1, participants + 1))
    random.shuffle(shuffled_ids)

    # Consecutive slices of the shuffled ids become groups 1, 2, ...
    allocation = [
        {"id": pid, "group": group}
        for group, (start, end) in enumerate(zip(splits, splits[1:]), start=1)
        for pid in shuffled_ids[start:end]
    ]
    allocation.sort(key=lambda row: row["id"])

    filename = list_to_csv(allocation, f"{DATA_DIR}/output.csv")
    return filename, pd.DataFrame(allocation)
def _parse_ratios(ratios: str) -> list:
    """Parse a ratio string such as ``"8:1:1"`` into a list of positive floats.

    Both the ASCII colon and the full-width colon "：" (the default from
    Chinese IMEs) are accepted as separators. Empty tokens (e.g. from a
    trailing colon) and zero weights are skipped; non-numeric, negative or
    non-finite tokens raise ``ValueError``.
    """
    import math

    ratio = []
    for raw in (ratios or "").replace("：", ":").split(":"):
        token = raw.strip()
        if not token:
            continue
        try:
            value = float(token)
        except ValueError:
            raise ValueError(f"Invalid input of ratio: {token!r}") from None
        if not math.isfinite(value) or value < 0:
            raise ValueError(f"Invalid input of ratio: {token!r}")
        if value > 0:
            ratio.append(value)
    if not ratio:
        raise ValueError("Invalid input of ratio: no positive group weight given")
    return ratio


def inference(participants: float, ratios: str):
    """Gradio handler: validate the inputs, reset DATA_DIR, run the allocation.

    Args:
        participants: Participant count from the number input; must be a
            positive whole number (``10.0`` is accepted, ``10.5`` is not).
        ratios: Colon-separated group weights, e.g. ``"8:1:1"``.

    Returns:
        The ``(csv_path, dataframe)`` pair from ``random_allocation``.

    Raises:
        ValueError: On invalid input. Raising, rather than printing and
            continuing, prevents a silently wrong allocation built from a
            partially parsed ratio list.
    """
    import math

    # Validate everything before touching the filesystem so a bad request
    # does not wipe the previous output.
    if (
        participants is None
        or not math.isfinite(participants)
        or participants < 1
        or participants != int(participants)
    ):
        raise ValueError("Number of participants must be a positive whole number.")
    ratio = _parse_ratios(ratios)

    # Start from an empty output directory on every run.
    # NOTE(review): DATA_DIR is a single shared path, so concurrent sessions
    # can delete each other's output; consider a per-request temp dir.
    if os.path.exists(DATA_DIR):
        shutil.rmtree(DATA_DIR)
    os.makedirs(DATA_DIR, exist_ok=True)
    return random_allocation(int(participants), ratio)
if __name__ == "__main__":
    # Build and launch the web UI: participant count and ratio string in,
    # CSV download and table preview out (both produced by inference()).
    gr.Interface(
        fn=inference,
        inputs=[
            gr.Number(
                label="输入参与者数量 (Number of participants)",
                value=10,
            ),
            gr.Textbox(label="输入分组比率 (Grouping ratio)", value="8:1:1"),
        ],
        outputs=[
            gr.File(label="下载随机分组数据 CSV (Download data CSV)"),
            gr.Dataframe(label="随机分组数据预览 (Data preview)"),
        ],
        title="随机对照试验随机数生成器<br>Randomized Controlled Trial Generator",
        description="输入参与者数量和分组比率,格式为用:隔开的数字,生成随机分组数据。<br>Enter the number of participants and the grouping ratio in the format of numbers separated by : to generate randomized grouping data.",
        # Gradio expects a string here ("never"/"auto"/"manual"); a boolean is
        # deprecated and only mapped to "never" with a warning.
        allow_flagging="never",
    ).launch()