"""Gradio tool that randomly allocates participants to groups (RCT helper)."""

import csv
import math
import random

import pandas as pd
import gradio as gr

from utils import clean_dir, TMP_DIR, EN_US

# Maps each Chinese UI string to its English translation; _L uses it when
# the EN_US flag imported from utils is truthy.
ZH2EN = {
    "输入参与者数量": "Number of participants",
    "输入分组比率 (格式为用:隔开的数字,生成随机分组数据)": "Grouping ratio (numbers separated by : to generate randomized controlled trial)",
    "状态栏": "Status",
    "下载随机分组数据 CSV": "Download data CSV",
    "随机分组数据预览": "Data preview",
}


def _L(zh_txt: str) -> str:
    """Return the UI text for *zh_txt* in the active language.

    The Chinese text doubles as the lookup key into ZH2EN, so a string that
    is missing from ZH2EN raises KeyError when EN_US is truthy.
    """
    return ZH2EN[zh_txt] if EN_US else zh_txt


def list_to_csv(list_of_dicts: list, filename: str):
    """Write a list of dicts to *filename* as a UTF-8 CSV file.

    The header (and column order) is taken from the keys of the first dict.

    Args:
        list_of_dicts: rows to write; every row must use the first row's keys.
        filename: destination path; an existing file is overwritten.

    Raises:
        ValueError: if *list_of_dicts* is empty, since there is no row to
            derive the header from.
    """
    if not list_of_dicts:
        raise ValueError("Cannot write CSV: there are no rows to write.")

    fieldnames = list(dict(list_of_dicts[0]).keys())
    # Write the dicts in the list to the CSV file.
    with open(filename, "w", newline="", encoding="utf-8") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(list_of_dicts)


def random_allocate(participants: int, ratio: list, out_csv: str):
    """Randomly assign participant ids 1..participants to groups by *ratio*.

    Group ``i + 1`` receives a share of participants proportional to
    ``ratio[i]``. Boundaries are computed by rounding the *cumulative* share,
    so the sizes always sum to ``participants`` and each group is within one
    participant of its exact quota (rounding remainder is not all dumped
    into the last group).

    Args:
        participants: number of participants; ids are 1..participants.
        ratio: positive, finite relative weights, one per group.
        out_csv: path the allocation CSV is written to.

    Returns:
        ``(out_csv, DataFrame)`` where the DataFrame has columns ``id`` and
        ``group`` and is sorted by ``id``.

    Raises:
        ValueError: if ``participants < 1``, *ratio* is empty, or any weight
            is not a positive finite number.
    """
    if participants < 1:
        raise ValueError("Number of participants must be at least 1.")
    if not ratio:
        raise ValueError("At least one positive group ratio is required.")
    if any(not (math.isfinite(r) and r > 0) for r in ratio):
        raise ValueError("Group ratios must be positive finite numbers.")

    total = sum(ratio)

    # Boundaries into the shuffled id list. Rounding the cumulative share
    # (instead of truncating each group's share) spreads the remainder.
    splits = [0]
    cumulative = 0
    for weight in ratio:
        cumulative += weight
        splits.append(round(cumulative / total * participants))
    # Guard against float drift: the last boundary must cover everyone.
    splits[-1] = participants

    shuffled = list(range(1, participants + 1))
    random.shuffle(shuffled)

    allocation = [
        {"id": participant, "group": group}
        for group, (start, end) in enumerate(zip(splits, splits[1:]), start=1)
        for participant in shuffled[start:end]
    ]
    sorted_data = sorted(allocation, key=lambda row: row["id"])
    list_to_csv(sorted_data, out_csv)
    return out_csv, pd.DataFrame(sorted_data)


# outer func
def infer(participants: float, ratios: str, cache=f"{TMP_DIR}/rct"):
    """Gradio callback: parse the UI inputs and run random_allocate.

    Args:
        participants: participant count from the Number input; must be a
            whole number.
        ratios: group ratios such as ``"8:1:1"``. Full-width colons are also
            accepted. Entries that are not > 0 are dropped, which means the
            remaining groups are renumbered consecutively.
        cache: directory that is cleaned via clean_dir and then receives
            ``output.csv``.

    Returns:
        ``(status, csv_path, preview_df)``. On failure *status* holds the
        error message and the other two values are None.
    """
    status = "Success"
    out_csv = previews = None
    try:
        if participants is None or not float(participants).is_integer():
            raise ValueError("Number of participants must be a whole number.")

        # Accept the full-width colon commonly typed with Chinese input methods.
        ratio_list = ratios.replace("：", ":").split(":")
        clean_dir(cache)
        weights = (float(part.strip()) for part in ratio_list)
        ratio = [w for w in weights if w > 0]
        out_csv, previews = random_allocate(
            int(participants), ratio, f"{cache}/output.csv"
        )
    except Exception as e:
        # UI boundary: surface any failure in the status box instead of
        # crashing the Gradio handler.
        status = f"{e}"

    return status, out_csv, previews


def rct_generator():
    """Build the Gradio interface for the random group allocation tool."""
    return gr.Interface(
        fn=infer,
        inputs=[
            gr.Number(label=_L("输入参与者数量"), value=10),
            gr.Textbox(
                label=_L("输入分组比率 (格式为用:隔开的数字,生成随机分组数据)"),
                value="8:1:1",
            ),
        ],
        outputs=[
            gr.Textbox(label=_L("状态栏"), show_copy_button=True),
            gr.File(label=_L("下载随机分组数据 CSV")),
            gr.Dataframe(label=_L("随机分组数据预览")),
        ],
        flagging_mode="never",
    )