Rajesh3338 committed
Commit 16fd3b1 · verified · 1 Parent(s): 2e56b20

Update app.py

Files changed (1)
  1. app.py +33 -85
app.py CHANGED
@@ -1,87 +1,35 @@
  import gradio as gr
- import spaces
  import torch
- from langchain.document_loaders import TextLoader
- from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain.embeddings import HuggingFaceEmbeddings
- from langchain.vectorstores import FAISS
- from langchain.llms import HuggingFacePipeline
- from langchain.chains import RetrievalQA
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
-
- # Load and process documents
- doc_loader = TextLoader("dataset.txt")
- docs = doc_loader.load()
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
- split_docs = text_splitter.split_documents(docs)
-
- # Create vector database
- embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
- vectordb = FAISS.from_documents(split_docs, embeddings)
-
- # Load model and create pipeline
- model_name = "Qwen/Qwen2.5-Coder-3B-Instruct"
-
- device = "cuda" if torch.cuda.is_available() else "cpu"
- print(f"Using device: {device}")
-
- tokenizer = AutoTokenizer.from_pretrained(model_name)
- #model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", torch_dtype="auto")
- model = AutoModelForCausalLM.from_pretrained(
-     model_name,
-     device_map=device,
-     torch_dtype=torch.float16 if device == "cuda" else torch.float32
- )
- #model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype="auto")
- qa_pipeline = pipeline(
-     "text-generation",
-     model=model,
-     tokenizer=tokenizer,
-     max_new_tokens=1500,
-     pad_token_id=tokenizer.eos_token_id
- )
-
- # Set up LangChain
- llm = HuggingFacePipeline(pipeline=qa_pipeline)
- retriever = vectordb.as_retriever(search_kwargs={"k": 5})
- qa_chain = RetrievalQA.from_chain_type(
-     retriever=retriever,
-     chain_type="stuff",
-     llm=llm,
-     return_source_documents=False
- )
-
- @spaces.GPU
- def preprocess_query(query):
-     if "script" in query or "code" in query.lower():
-         return f"Write a CPSL script: {query}"
-     return query
-
- @spaces.GPU
- def clean_response(response):
-     result = response.get("result", "")
-     if "Answer:" in result:
-         return result.split("Answer:")[1].strip()
-     return result.strip()
-
- @spaces.GPU
- def chatbot_response(user_input):
-     processed_query = preprocess_query(user_input)
-     raw_response = qa_chain.invoke({"query": processed_query})
-     return clean_response(raw_response)
-
- with gr.Blocks() as demo:  # Removed @spaces.GPU here
-     gr.Markdown("# CPSL Chatbot")
-     chat_history = gr.Chatbot()
-     user_input = gr.Textbox(label="Your Message:")
-     send_button = gr.Button("Send")
-
-     @spaces.GPU
-     def interact(user_message, history):
-         bot_reply = chatbot_response(user_message)
-         history.append((user_message, bot_reply))
-         return history, history
-
-     send_button.click(interact, inputs=[user_input, chat_history], outputs=[chat_history, chat_history])
-
- demo.launch()
+ from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
+ from PIL import Image
+
+ # Load model and processor
+ model_id = "google/paligemma2-28b-mix-448"
+ model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto").eval()
+ processor = PaliGemmaProcessor.from_pretrained(model_id)
+
+ def generate_description(image, prompt):
+     if image is None:
+         return "Please upload an image."
+
+     model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(torch.bfloat16).to(model.device)
+     input_len = model_inputs["input_ids"].shape[-1]
+
+     with torch.inference_mode():
+         generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
+         generation = generation[0][input_len:]
+         decoded = processor.decode(generation, skip_special_tokens=True)
+
+     return decoded
+
+ # Gradio UI
+ with gr.Blocks() as demo:
+     gr.Markdown("# PaliGemma Image Captioning")
+     image_input = gr.Image(type="pil", label="Upload Image")
+     prompt_input = gr.Textbox(label="Enter Prompt", value="describe en")
+     output_text = gr.Textbox(label="Generated Description")
+     submit_button = gr.Button("Generate")
+
+     submit_button.click(generate_description, inputs=[image_input, prompt_input], outputs=output_text)
+
+ demo.launch()
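
For reference, the captioning logic added in this commit can also be exercised outside the Gradio UI. The standalone sketch below is not part of the commit; it mirrors the new app.py code, the local file "sample.jpg" is a hypothetical stand-in for an uploaded image, and the 28B checkpoint requires substantial GPU memory.

# Standalone sanity check (not part of the commit) mirroring the new app.py logic.
import torch
from PIL import Image
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration

model_id = "google/paligemma2-28b-mix-448"
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
).eval()
processor = PaliGemmaProcessor.from_pretrained(model_id)

image = Image.open("sample.jpg")  # hypothetical local image path
inputs = processor(text="describe en", images=image, return_tensors="pt").to(torch.bfloat16).to(model.device)
input_len = inputs["input_ids"].shape[-1]

with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=100, do_sample=False)

# Strip the prompt tokens and decode only the newly generated text.
print(processor.decode(output[0][input_len:], skip_special_tokens=True))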