update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -4,11 +4,9 @@ from wikipediaapi import Wikipedia | |
| 4 | 
             
            import textwrap
         | 
| 5 | 
             
            import numpy as np
         | 
| 6 | 
             
            from openai import OpenAI
         | 
| 7 | 
            -
            from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
         | 
| 8 | 
            -
            import os
         | 
| 9 |  | 
| 10 | 
             
            # Function to process the input and generate the output
         | 
| 11 | 
            -
            def process_query(wiki_page, embed_dim, query, mode): | 
| 12 | 
             
                model_mapping = {
         | 
| 13 | 
             
                    "Arabic-mpnet-base-all-nli-triplet": "Omartificial-Intelligence-Space/Arabic-mpnet-base-all-nli-triplet",
         | 
| 14 | 
             
                    "Arabic-all-nli-triplet-Matryoshka": "Omartificial-Intelligence-Space/Arabic-all-nli-triplet-Matryoshka",
         | 
| @@ -17,78 +15,70 @@ def process_query(wiki_page, embed_dim, query, mode): | |
| 17 | 
             
                    "Marbert-all-nli-triplet-Matryoshka": "Omartificial-Intelligence-Space/Marbert-all-nli-triplet-Matryoshka"
         | 
| 18 | 
             
                }
         | 
| 19 |  | 
| 20 | 
            -
                 | 
| 21 | 
            -
                 | 
| 22 | 
            -
             | 
| 23 | 
             
                wiki = Wikipedia('RAGBot/0.0', 'ar')
         | 
| 24 | 
             
                doc = wiki.page(wiki_page).text
         | 
| 25 | 
             
                paragraphs = doc.split('\n\n')  # chunking
         | 
|  | |
| 26 | 
             
                for i, p in enumerate(paragraphs):
         | 
| 27 | 
             
                    wrapped_text = textwrap.fill(p, width=100)
         | 
| 28 |  | 
| 29 | 
            -
                 | 
| 30 | 
            -
             | 
| 31 | 
            -
                 | 
| 32 | 
            -
             | 
| 33 | 
            -
             | 
| 34 | 
            -
                    query_embed = model.encode(query, normalize_embeddings=True)
         | 
| 35 | 
            -
                    similarities = np.dot(docs_embed, query_embed.T)
         | 
| 36 | 
            -
                    top_3_idx = np.argsort(similarities, axis=0)[-3:][::-1].tolist()
         | 
| 37 | 
            -
                    most_similar_documents = [paragraphs[idx] for idx in top_3_idx]
         | 
| 38 | 
            -
             | 
| 39 | 
            -
                    CONTEXT = ""
         | 
| 40 | 
            -
                    for p in most_similar_documents:
         | 
| 41 | 
            -
                        wrapped_text = textwrap.fill(p, width=100)
         | 
| 42 | 
            -
                        CONTEXT += wrapped_text + "\n\n"
         | 
| 43 | 
            -
             | 
| 44 | 
            -
                    prompt = f"""
         | 
| 45 | 
            -
                        use the following CONTEXT to answer the QUESTION at the end.
         | 
| 46 | 
            -
                        If you don't know the answer, just say that you don't know, don't try to make up an answer.
         | 
| 47 | 
            -
                        CONTEXT: {CONTEXT}
         | 
| 48 | 
            -
                        QUESTION: {query}
         | 
| 49 | 
            -
                    """
         | 
| 50 | 
            -
             | 
| 51 | 
            -
                    if mode == "OpenAI":
         | 
| 52 | 
            -
                        client = OpenAI(api_key=openai_api_key)
         | 
| 53 | 
            -
                        response = client.chat.completions.create(
         | 
| 54 | 
            -
                            model="gpt-4",
         | 
| 55 | 
            -
                            messages=[
         | 
| 56 | 
            -
                                {"role": "user", "content": prompt},
         | 
| 57 | 
            -
                            ]
         | 
| 58 | 
            -
                        )
         | 
| 59 | 
            -
                        responses[model_name] = response.choices[0].message.content
         | 
| 60 | 
            -
             | 
| 61 | 
            -
                    elif mode == "OpenSource":
         | 
| 62 | 
            -
                        tokenizer = AutoTokenizer.from_pretrained("google/gemini-2b", use_auth_token=hf_token)
         | 
| 63 | 
            -
                        model = AutoModelForCausalLM.from_pretrained("google/gemini-2b", use_auth_token=hf_token)
         | 
| 64 | 
            -
                        generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
         | 
| 65 | 
            -
                        response = generator(prompt, max_length=512, num_return_sequences=1)
         | 
| 66 | 
            -
                        responses[model_name] = response[0]['generated_text']
         | 
| 67 | 
            -
             | 
| 68 | 
            -
                return "\n\n".join([f"Model: {model_name}\nResponse: {response}" for model_name, response in responses.items()])
         | 
| 69 |  | 
| 70 | 
            -
             | 
| 71 | 
            -
                 | 
| 72 | 
            -
             | 
| 73 | 
            -
             | 
| 74 | 
            -
             | 
| 75 | 
            -
             | 
| 76 | 
            -
                     | 
| 77 | 
            -
             | 
| 78 | 
            -
             | 
| 79 | 
            -
             | 
| 80 | 
            -
             | 
| 81 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 82 | 
             
                )
         | 
| 83 |  | 
| 84 | 
            -
                 | 
| 85 | 
            -
             | 
| 86 | 
            -
             | 
| 87 | 
            -
             | 
| 88 | 
            -
             | 
| 89 | 
            -
             | 
| 90 | 
            -
             | 
| 91 | 
            -
             | 
| 92 | 
            -
                 | 
| 93 | 
            -
             | 
| 94 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 4 | 
             
            import textwrap
         | 
| 5 | 
             
            import numpy as np
         | 
| 6 | 
             
            from openai import OpenAI
         | 
|  | |
|  | |
| 7 |  | 
| 8 | 
             
            # Function to process the input and generate the output
         | 
| 9 | 
            +
            def process_query(wiki_page, model_name, embed_dim, query, api_key):
         | 
| 10 | 
             
                model_mapping = {
         | 
| 11 | 
             
                    "Arabic-mpnet-base-all-nli-triplet": "Omartificial-Intelligence-Space/Arabic-mpnet-base-all-nli-triplet",
         | 
| 12 | 
             
                    "Arabic-all-nli-triplet-Matryoshka": "Omartificial-Intelligence-Space/Arabic-all-nli-triplet-Matryoshka",
         | 
|  | |
| 15 | 
             
                    "Marbert-all-nli-triplet-Matryoshka": "Omartificial-Intelligence-Space/Marbert-all-nli-triplet-Matryoshka"
         | 
| 16 | 
             
                }
         | 
| 17 |  | 
| 18 | 
            +
                model_path = model_mapping[model_name]
         | 
| 19 | 
            +
                model = SentenceTransformer(model_path, trust_remote_code=True, truncate_dim=embed_dim)
         | 
|  | |
| 20 | 
             
                wiki = Wikipedia('RAGBot/0.0', 'ar')
         | 
| 21 | 
             
                doc = wiki.page(wiki_page).text
         | 
| 22 | 
             
                paragraphs = doc.split('\n\n')  # chunking
         | 
| 23 | 
            +
             | 
| 24 | 
             
                for i, p in enumerate(paragraphs):
         | 
| 25 | 
             
                    wrapped_text = textwrap.fill(p, width=100)
         | 
| 26 |  | 
| 27 | 
            +
                docs_embed = model.encode(paragraphs, normalize_embeddings=True)
         | 
| 28 | 
            +
                query_embed = model.encode(query, normalize_embeddings=True)
         | 
| 29 | 
            +
                similarities = np.dot(docs_embed, query_embed.T)
         | 
| 30 | 
            +
                top_3_idx = np.argsort(similarities, axis=0)[-3:][::-1].tolist()
         | 
| 31 | 
            +
                most_similar_documents = [paragraphs[idx] for idx in top_3_idx]
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 32 |  | 
| 33 | 
            +
                CONTEXT = ""
         | 
| 34 | 
            +
                for i, p in enumerate(most_similar_documents):
         | 
| 35 | 
            +
                    wrapped_text = textwrap.fill(p, width=100)
         | 
| 36 | 
            +
                    CONTEXT += wrapped_text + "\n\n"
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                prompt = f"""
         | 
| 39 | 
            +
                    use the following CONTEXT to answer the QUESTION at the end.
         | 
| 40 | 
            +
                    If you don't know the answer, just say that you don't know, don't try to make up an answer.
         | 
| 41 | 
            +
                    CONTEXT: {CONTEXT}
         | 
| 42 | 
            +
                    QUESTION: {query}
         | 
| 43 | 
            +
                """
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                client = OpenAI(api_key=api_key)
         | 
| 46 | 
            +
                response = client.chat.completions.create(
         | 
| 47 | 
            +
                    model="gpt-4o",
         | 
| 48 | 
            +
                    messages=[
         | 
| 49 | 
            +
                        {"role": "user", "content": prompt},
         | 
| 50 | 
            +
                    ]
         | 
| 51 | 
             
                )
         | 
| 52 |  | 
| 53 | 
            +
                return response.choices[0].message.content
         | 
| 54 | 
            +
             | 
| 55 | 
            +
            # Define the interface
         | 
| 56 | 
            +
            wiki_page_input = gr.Textbox(label="Wikipedia Page (in Arabic)")
         | 
| 57 | 
            +
            query_input = gr.Textbox(label="Query (in Arabic)")
         | 
| 58 | 
            +
            api_key_input = gr.Textbox(label="OpenAI API Key", type="password")
         | 
| 59 | 
            +
             | 
| 60 | 
            +
            model_choice = gr.Dropdown(
         | 
| 61 | 
            +
                choices=[
         | 
| 62 | 
            +
                    "Arabic-mpnet-base-all-nli-triplet",
         | 
| 63 | 
            +
                    "Arabic-all-nli-triplet-Matryoshka",
         | 
| 64 | 
            +
                    "Arabert-all-nli-triplet-Matryoshka",
         | 
| 65 | 
            +
                    "Arabic-labse-Matryoshka",
         | 
| 66 | 
            +
                    "Marbert-all-nli-triplet-Matryoshka"
         | 
| 67 | 
            +
                ], 
         | 
| 68 | 
            +
                label="Choose Embedding Model"
         | 
| 69 | 
            +
            )
         | 
| 70 | 
            +
             | 
| 71 | 
            +
            embed_dim_choice = gr.Dropdown(
         | 
| 72 | 
            +
                choices=[768, 512, 256, 128, 64],
         | 
| 73 | 
            +
                label="Embedding Dimension"
         | 
| 74 | 
            +
            )
         | 
| 75 | 
            +
             | 
| 76 | 
            +
            output_text = gr.Textbox(label="Output")
         | 
| 77 | 
            +
             | 
| 78 | 
            +
            gr.Interface(
         | 
| 79 | 
            +
                fn=process_query,
         | 
| 80 | 
            +
                inputs=[wiki_page_input, model_choice, embed_dim_choice, query_input, api_key_input],
         | 
| 81 | 
            +
                outputs=output_text,
         | 
| 82 | 
            +
                title="Arabic Wiki RAG",
         | 
| 83 | 
            +
                description="Choose a Wikipedia page, embedding model, and dimension to answer a query in Arabic."
         | 
| 84 | 
            +
            ).launch()
         | 
