|
from transformers import AutoTokenizer, AutoModel |
|
import torch |
|
import torch.nn.functional as F |
|
|
|
def mean_pooling(model_output, attention_mask):
    """Average the token embeddings of each sequence, ignoring masked positions.

    Args:
        model_output: Indexable model output; element ``0`` is used as the
            per-token embeddings.
        attention_mask: Mask tensor; it is broadcast over the last dimension of
            the token embeddings and used as a 0/1 weight per token.

    Returns:
        The mask-weighted sum over dimension 1 divided by the mask count,
        with the divisor clamped to at least ``1e-9``.
    """
    hidden_states = model_output[0]

    # Give the mask a trailing axis and stretch it to the embeddings' shape
    # so it can be multiplied element-wise.
    mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()

    weighted_sum = torch.sum(hidden_states * mask, 1)
    # Clamp so a fully masked row divides by a tiny number instead of zero.
    token_counts = torch.clamp(mask.sum(1), min=1e-9)
    return weighted_sum / token_counts
|
|
|
def cosine_similarity(u, v):
    """Return the cosine similarity of ``u`` and ``v`` computed along dimension 1."""
    scores = F.cosine_similarity(u, v, dim=1)
    return scores
|
|
|
|
|
_NEGATION_MODEL_NAME = "dmlls/all-mpnet-base-v2-negation"


def _load_negation_encoder():
    """Return the ``(tokenizer, model)`` pair, loading it only on first use."""
    cached = getattr(_load_negation_encoder, "_cached", None)
    if cached is None:
        # from_pretrained is expensive; keep the loaded pair on the function
        # object so later calls reuse it instead of reloading per request.
        cached = (
            AutoTokenizer.from_pretrained(_NEGATION_MODEL_NAME),
            AutoModel.from_pretrained(_NEGATION_MODEL_NAME),
        )
        _load_negation_encoder._cached = cached
    return cached


def compare(text1, text2):
    """Return the cosine similarity between the embeddings of two texts.

    Both texts are tokenized together (padded and truncated), encoded with the
    ``dmlls/all-mpnet-base-v2-negation`` model, pooled with ``mean_pooling``,
    L2-normalized and then compared with ``cosine_similarity``.

    Args:
        text1: First text.
        text2: Second text.

    Returns:
        The similarity score as a Python number (``Tensor.item()``).
    """
    tokenizer, model = _load_negation_encoder()

    encoded_input = tokenizer(
        [text1, text2], padding=True, truncation=True, return_tensors="pt"
    )

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        model_output = model(**encoded_input)

    sentence_embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

    # Re-add the batch axis so each embedding is a 1-row matrix, as
    # cosine_similarity compares along dimension 1.
    similarity_score = cosine_similarity(
        sentence_embeddings[0].unsqueeze(0), sentence_embeddings[1].unsqueeze(0)
    )
    return similarity_score.item()
|
|
|
|
|
|
|
from fastapi import FastAPI |
|
|
|
# Application instance; the route decorators in this file register on it.
app = FastAPI()
|
|
|
@app.get("/")
def greet_json():
    """Root endpoint: return a fixed greeting payload."""
    greeting = {"Hello": "World!"}
    return greeting
|
|
|
|
|
|
|
from transformers import pipeline |
|
|
|
# Summarization pipeline built once at import time and used by Summerized_Text.
summarizer = pipeline(task="summarization", model="facebook/bart-large-cnn")
|
|
|
def Summerized_Text(text):
    """Summarize ``text`` with the module-level ``summarizer`` pipeline.

    Args:
        text: Text to summarize; surrounding whitespace is stripped first.

    Returns:
        The ``summary_text`` of the first pipeline result, or ``""`` when the
        stripped input is empty (the model is not invoked in that case).
    """
    text = text.strip()
    if not text:
        # Nothing to summarize; do not run the model on an empty string.
        return ""

    # truncation=True clips over-long inputs to the model's maximum input
    # length instead of letting the forward pass fail on them.
    results = summarizer(
        text, max_length=130, min_length=30, do_sample=False, truncation=True
    )
    return results[0]["summary_text"]
|
|
|
|
|
|
|
from fastapi.responses import JSONResponse |
|
from pydantic import BaseModel |
|
from fastapi import FastAPI |
|
|
|
class StrRequest(BaseModel):
    # Request body accepted by the POST /api/summerized handler.
    # `text` is the string that handler passes on for summarization.
    text: str
|
|
|
|
|
class CompareRequest(BaseModel):
    # Request body accepted by the POST /api/compare handler, which scores
    # `summary` against `text`.
    summary: str
    text: str
|
|
|
|
|
@app.get("/api/check")
def check_connection():
    """Connectivity check: reply with a fixed success payload (HTTP 200).

    Any exception raised while building the response is printed and reported
    as an HTTP 500 payload carrying the exception text.
    """
    try:
        ok_payload = {"status": 200, "message": "Message Successfully Sent"}
        return JSONResponse(ok_payload, status_code=200)
    except Exception as err:
        print("Error => ", err)
        failure_payload = {"status": 500, "message": str(err)}
        return JSONResponse(failure_payload, status_code=500)
|
|
|
|
|
@app.post("/api/summerized")
async def get_summerized(request: StrRequest):
    """Summarize the posted text and return it in a JSON envelope.

    Responses:
        * 422 - ``text`` is empty or whitespace-only.
        * 500 - the summary is empty or contains "No abstract text.", or an
          exception was raised (its message is returned).
        * 200 - success; the summary is in ``data``.

    The HTTP status code always matches the ``status`` field in the body.
    """
    # Imported locally so this handler does not depend on edits to the file's
    # import section.
    from fastapi.concurrency import run_in_threadpool

    try:
        text = request.text
        if not text or not text.strip():
            return JSONResponse(
                {"status": 422, "message": "Invalid Input"}, status_code=422
            )

        # The summarizer is blocking, CPU-bound work; run it in a worker
        # thread so this coroutine does not stall the event loop.
        summary = await run_in_threadpool(Summerized_Text, text)

        if "No abstract text." in summary:
            return JSONResponse(
                {"status": 500, "message": "No matching text found", "data": "None"},
                status_code=500,
            )

        if not summary:
            return JSONResponse(
                {"status": 500, "message": "No matching text found", "data": {}},
                status_code=500,
            )

        return JSONResponse(
            {"status": 200, "message": "Matching text found", "data": summary}
        )

    except Exception as e:
        # Top-level boundary of the handler: report the failure to the client.
        print("Error => ", e)
        return JSONResponse({"status": 500, "message": str(e)}, status_code=500)
|
|
|
|
|
@app.post("/api/compare")
def compareTexts(request: CompareRequest):
    """Score the posted summary against the posted text.

    Returns a 422 payload when either field is empty, otherwise a payload
    with the ``compare`` score under ``value`` plus the echoed inputs. Any
    exception is printed and reported as an HTTP 500 payload.
    """
    try:
        source_text, summary_text = request.text, request.summary

        # Guard clause: both fields must be non-empty.
        if not (summary_text and source_text):
            invalid_payload = {"status": 422, "message": "Invalid Input"}
            return JSONResponse(invalid_payload, status_code=422)

        score = compare(source_text, summary_text)
        payload = {
            "status": 200,
            "message": "Comparisons made",
            "value": score,
            "text": source_text,
            "summary": summary_text,
        }
        return JSONResponse(payload)
    except Exception as err:
        print("Error => ", err)
        return JSONResponse({"status": 500, "message": str(err)}, status_code=500)
|
|
|
|