Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	
		Enzo Reis de Oliveira
		
	commited on
		
		
					Commit 
							
							·
						
						073cdd9
	
1
								Parent(s):
							
							cbc085f
								
Better error message for batch
Browse files
    	
        app.py
    CHANGED
    
    | @@ -4,14 +4,14 @@ import json | |
| 4 | 
             
            import pandas as pd
         | 
| 5 | 
             
            import gradio as gr
         | 
| 6 |  | 
| 7 | 
            -
            # 1)  | 
| 8 | 
             
            BASE_DIR = os.path.dirname(os.path.abspath(__file__))
         | 
| 9 | 
             
            INFERENCE_PATH = os.path.join(BASE_DIR, "smi-ted", "inference")
         | 
| 10 | 
             
            sys.path.insert(0, INFERENCE_PATH)
         | 
| 11 |  | 
| 12 | 
             
            from smi_ted_light.load import load_smi_ted
         | 
| 13 |  | 
| 14 | 
            -
            # 2)  | 
| 15 | 
             
            MODEL_DIR = os.path.join(INFERENCE_PATH, "smi_ted_light")
         | 
| 16 | 
             
            model = load_smi_ted(
         | 
| 17 | 
             
                folder=MODEL_DIR,
         | 
| @@ -19,14 +19,15 @@ model = load_smi_ted( | |
| 19 | 
             
                vocab_filename="bert_vocab_curated.txt",
         | 
| 20 | 
             
            )
         | 
| 21 |  | 
| 22 | 
            -
            # 3) Single function to process either a single SMILES or a CSV of SMILES
         | 
| 23 | 
             
            def process_inputs(smiles: str, file_obj):
         | 
| 24 | 
            -
                #  | 
| 25 | 
             
                if file_obj is not None:
         | 
| 26 | 
             
                    try:
         | 
|  | |
| 27 | 
             
                        df_in = pd.read_csv(file_obj.name, sep=None, engine='python')
         | 
| 28 |  | 
| 29 | 
            -
                         | 
|  | |
| 30 | 
             
                        if not smiles_cols:
         | 
| 31 | 
             
                            return (
         | 
| 32 | 
             
                                "Error: The CSV must have a column named 'Smiles' with the respective SMILES.",
         | 
| @@ -35,42 +36,68 @@ def process_inputs(smiles: str, file_obj): | |
| 35 | 
             
                        smiles_col = smiles_cols[0]
         | 
| 36 | 
             
                        smiles_list = df_in[smiles_col].astype(str).tolist()
         | 
| 37 |  | 
| 38 | 
            -
                         | 
| 39 | 
            -
             | 
| 40 | 
            -
             | 
| 41 | 
            -
                                gr.update(visible=False),
         | 
| 42 | 
            -
                            )
         | 
| 43 |  | 
| 44 | 
            -
                         | 
| 45 | 
             
                        for sm in smiles_list:
         | 
| 46 | 
            -
                             | 
| 47 | 
            -
             | 
| 48 | 
            -
             | 
| 49 | 
            -
             | 
| 50 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 51 | 
             
                        out_df.to_csv("embeddings.csv", index=False)
         | 
| 52 |  | 
| 53 | 
            -
                         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 54 | 
             
                        return msg, gr.update(value="embeddings.csv", visible=True)
         | 
| 55 |  | 
| 56 | 
             
                    except Exception as e:
         | 
| 57 | 
             
                        return f"Error processing batch: {e}", gr.update(visible=False)
         | 
| 58 |  | 
| 59 | 
            -
                # Modo single
         | 
| 60 | 
             
                smiles = smiles.strip()
         | 
| 61 | 
             
                if not smiles:
         | 
| 62 | 
             
                    return "Please enter a SMILES or upload a CSV file.", gr.update(visible=False)
         | 
| 63 | 
             
                try:
         | 
| 64 | 
             
                    vec = model.encode(smiles, return_torch=True)[0].tolist()
         | 
| 65 | 
            -
                    # Salva CSV com header
         | 
| 66 | 
             
                    cols = ["smiles"] + [f"dim_{i}" for i in range(len(vec))]
         | 
| 67 | 
             
                    df_out = pd.DataFrame([[smiles] + vec], columns=cols)
         | 
| 68 | 
             
                    df_out.to_csv("embeddings.csv", index=False)
         | 
| 69 | 
             
                    return json.dumps(vec), gr.update(value="embeddings.csv", visible=True)
         | 
| 70 | 
            -
                except Exception | 
| 71 | 
             
                    return f"The following input '{smiles}' is not a valid SMILES", gr.update(visible=False)
         | 
| 72 |  | 
| 73 | 
            -
             | 
|  | |
| 74 | 
             
            with gr.Blocks() as demo:
         | 
| 75 | 
             
                gr.Markdown(
         | 
| 76 | 
             
                    """
         | 
| @@ -88,7 +115,7 @@ with gr.Blocks() as demo: | |
| 88 | 
             
                generate_btn = gr.Button("Extract Embeddings")
         | 
| 89 |  | 
| 90 | 
             
                with gr.Row():
         | 
| 91 | 
            -
                    output_msg   = gr.Textbox(label="Message / Embedding (JSON)", interactive=False, lines= | 
| 92 | 
             
                    download_csv = gr.File(label="Download embeddings.csv", visible=False)
         | 
| 93 |  | 
| 94 | 
             
                generate_btn.click(
         | 
|  | |
| 4 | 
             
            import pandas as pd
         | 
| 5 | 
             
            import gradio as gr
         | 
| 6 |  | 
| 7 | 
            +
            # 1) Ajusta o path antes de importar o loader
         | 
| 8 | 
             
            BASE_DIR = os.path.dirname(os.path.abspath(__file__))
         | 
| 9 | 
             
            INFERENCE_PATH = os.path.join(BASE_DIR, "smi-ted", "inference")
         | 
| 10 | 
             
            sys.path.insert(0, INFERENCE_PATH)
         | 
| 11 |  | 
| 12 | 
             
            from smi_ted_light.load import load_smi_ted
         | 
| 13 |  | 
| 14 | 
            +
            # 2) Carrega o modelo SMI-TED Light
         | 
| 15 | 
             
            MODEL_DIR = os.path.join(INFERENCE_PATH, "smi_ted_light")
         | 
| 16 | 
             
            model = load_smi_ted(
         | 
| 17 | 
             
                folder=MODEL_DIR,
         | 
|  | |
| 19 | 
             
                vocab_filename="bert_vocab_curated.txt",
         | 
| 20 | 
             
            )
         | 
| 21 |  | 
|  | |
| 22 | 
             
            def process_inputs(smiles: str, file_obj):
         | 
| 23 | 
            +
                # Modo batch
         | 
| 24 | 
             
                if file_obj is not None:
         | 
| 25 | 
             
                    try:
         | 
| 26 | 
            +
                        # autodetecta delimitador (; ou , etc)
         | 
| 27 | 
             
                        df_in = pd.read_csv(file_obj.name, sep=None, engine='python')
         | 
| 28 |  | 
| 29 | 
            +
                        # procura coluna "smiles" (case‐insensitive)
         | 
| 30 | 
            +
                        smiles_cols = [c for c in df_in.columns if c.lower() == "smiles"]
         | 
| 31 | 
             
                        if not smiles_cols:
         | 
| 32 | 
             
                            return (
         | 
| 33 | 
             
                                "Error: The CSV must have a column named 'Smiles' with the respective SMILES.",
         | 
|  | |
| 36 | 
             
                        smiles_col = smiles_cols[0]
         | 
| 37 | 
             
                        smiles_list = df_in[smiles_col].astype(str).tolist()
         | 
| 38 |  | 
| 39 | 
            +
                        out_records = []
         | 
| 40 | 
            +
                        invalid_smiles = []
         | 
| 41 | 
            +
                        embed_dim = None
         | 
|  | |
|  | |
| 42 |  | 
| 43 | 
            +
                        # para cada SMILES, tenta gerar embedding
         | 
| 44 | 
             
                        for sm in smiles_list:
         | 
| 45 | 
            +
                            try:
         | 
| 46 | 
            +
                                vec = model.encode(sm, return_torch=True)[0].tolist()
         | 
| 47 | 
            +
                                # guarda dimensão do vetor na primeira vez
         | 
| 48 | 
            +
                                if embed_dim is None:
         | 
| 49 | 
            +
                                    embed_dim = len(vec)
         | 
| 50 | 
            +
                                # monta registro válido
         | 
| 51 | 
            +
                                record = {"smiles": sm}
         | 
| 52 | 
            +
                                record.update({f"dim_{i}": v for i, v in enumerate(vec)})
         | 
| 53 | 
            +
                            except Exception:
         | 
| 54 | 
            +
                                # marca como inválido
         | 
| 55 | 
            +
                                invalid_smiles.append(sm)
         | 
| 56 | 
            +
                                # se já souber quantos dims, preenche com None
         | 
| 57 | 
            +
                                if embed_dim is not None:
         | 
| 58 | 
            +
                                    record = {"smiles": f"SMILES {sm} was invalid"}
         | 
| 59 | 
            +
                                    record.update({f"dim_{i}": None for i in range(embed_dim)})
         | 
| 60 | 
            +
                                else:
         | 
| 61 | 
            +
                                    # ainda não sabemos quantos dims: só guarda smiles
         | 
| 62 | 
            +
                                    record = {"smiles": f"SMILES {sm} was invalid"}
         | 
| 63 | 
            +
                            out_records.append(record)
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                        # converte para DataFrame (vai unificar todas as colunas)
         | 
| 66 | 
            +
                        out_df = pd.DataFrame(out_records)
         | 
| 67 | 
             
                        out_df.to_csv("embeddings.csv", index=False)
         | 
| 68 |  | 
| 69 | 
            +
                        # monta mensagem de saída
         | 
| 70 | 
            +
                        total = len(smiles_list)
         | 
| 71 | 
            +
                        valid = total - len(invalid_smiles)
         | 
| 72 | 
            +
                        if invalid_smiles:
         | 
| 73 | 
            +
                            msg = (
         | 
| 74 | 
            +
                                f"{valid} SMILES were successfully processed, "
         | 
| 75 | 
            +
                                f"{len(invalid_smiles)} had errors:\n"
         | 
| 76 | 
            +
                                + "\n".join(invalid_smiles)
         | 
| 77 | 
            +
                            )
         | 
| 78 | 
            +
                        else:
         | 
| 79 | 
            +
                            msg = f"Processed batch of {valid} SMILES. Download embeddings.csv."
         | 
| 80 | 
            +
             | 
| 81 | 
             
                        return msg, gr.update(value="embeddings.csv", visible=True)
         | 
| 82 |  | 
| 83 | 
             
                    except Exception as e:
         | 
| 84 | 
             
                        return f"Error processing batch: {e}", gr.update(visible=False)
         | 
| 85 |  | 
| 86 | 
            +
                # Modo single (sem mudança)
         | 
| 87 | 
             
                smiles = smiles.strip()
         | 
| 88 | 
             
                if not smiles:
         | 
| 89 | 
             
                    return "Please enter a SMILES or upload a CSV file.", gr.update(visible=False)
         | 
| 90 | 
             
                try:
         | 
| 91 | 
             
                    vec = model.encode(smiles, return_torch=True)[0].tolist()
         | 
|  | |
| 92 | 
             
                    cols = ["smiles"] + [f"dim_{i}" for i in range(len(vec))]
         | 
| 93 | 
             
                    df_out = pd.DataFrame([[smiles] + vec], columns=cols)
         | 
| 94 | 
             
                    df_out.to_csv("embeddings.csv", index=False)
         | 
| 95 | 
             
                    return json.dumps(vec), gr.update(value="embeddings.csv", visible=True)
         | 
| 96 | 
            +
                except Exception:
         | 
| 97 | 
             
                    return f"The following input '{smiles}' is not a valid SMILES", gr.update(visible=False)
         | 
| 98 |  | 
| 99 | 
            +
             | 
| 100 | 
            +
            # 4) Interface Gradio (sem mudanças)
         | 
| 101 | 
             
            with gr.Blocks() as demo:
         | 
| 102 | 
             
                gr.Markdown(
         | 
| 103 | 
             
                    """
         | 
|  | |
| 115 | 
             
                generate_btn = gr.Button("Extract Embeddings")
         | 
| 116 |  | 
| 117 | 
             
                with gr.Row():
         | 
| 118 | 
            +
                    output_msg   = gr.Textbox(label="Message / Embedding (JSON)", interactive=False, lines=4)
         | 
| 119 | 
             
                    download_csv = gr.File(label="Download embeddings.csv", visible=False)
         | 
| 120 |  | 
| 121 | 
             
                generate_btn.click(
         | 
