Spaces:
Configuration error
Configuration error
Commit
·
5b81931
1
Parent(s):
928c735
Upload 14 files
Browse files- .env +1 -0
- .gitattributes +7 -34
- README.md +2 -12
- __pycache__/process.cpython-37.pyc +0 -0
- amharic.csv +3 -0
- app.py +51 -0
- hausa.csv +3 -0
- igbo.csv +0 -0
- news.ann +3 -0
- process.py +133 -0
- requirements.txt +16 -0
- swahili.csv +3 -0
- utils.py +32 -0
- yoruba.csv +3 -0
.env
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# NOTE: never commit a real API key — the value previously committed here was a live secret and must be revoked/rotated.
COHERE_API_KEY = <your-cohere-api-key>
|
.gitattributes
CHANGED
|
@@ -1,34 +1,7 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 1 |
+
# Auto detect text files and perform LF normalization
|
| 2 |
+
* text=auto
|
| 3 |
+
amharic.csv filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
hausa.csv filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
news.ann filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
swahili.csv filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
yoruba.csv filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
|
@@ -1,12 +1,2 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
emoji: 🔥
|
| 4 |
-
colorFrom: gray
|
| 5 |
-
colorTo: green
|
| 6 |
-
sdk: streamlit
|
| 7 |
-
sdk_version: 1.19.0
|
| 8 |
-
app_file: app.py
|
| 9 |
-
pinned: false
|
| 10 |
-
---
|
| 11 |
-
|
| 12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
+
# cluster_news
|
| 2 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__pycache__/process.cpython-37.pyc
ADDED
|
Binary file (3.25 kB). View file
|
|
|
amharic.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:59b8670c19f95f0cff667b8d5f69033e93bcdd2dec5e1cc069f82d93699da894
|
| 3 |
+
size 36144176
|
app.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st

from process import *

# Load a sampled, deduplicated multilingual news corpus once per script run.
df = import_ds()

st.title('AFri News Multilingual Embedding')

form = st.form(key="user_settings")

# Containers rendered below the form: the matched articles, then the 2-D map.
textcontainer = st.container()

plotcontainer = st.container()

with form:
    query = st.text_input('Please input your news text here:')

    num_nearest = int(st.slider('Please input the number of news to find: ', value=15, min_value=1, max_value=200))

    generate_button = form.form_submit_button("Cluster News")

if generate_button:
    key = get_key()
    co = cohere.Client(key)

    # Embed the sampled corpus and persist an Annoy index over it.
    embeddings = getEmbeddings(co, df)
    indexfile = 'news.ann'
    semantic_search(embeddings, indexfile)

    # Embed the user's query and look up its nearest articles in the index.
    query_embed = get_query_embed(co, query)
    nearest_ids = getClosestNeighbours(indexfile, query_embed, num_nearest)

    # Project the whole corpus down to 2-D for the scatter plot.
    # (The previous version also built an unused nn_embeddings/all_embeddings
    # stack here; that dead computation has been removed.)
    umap_embeds = getUMAPEmbed(embeddings)

    text_news = display_news(df, nearest_ids)
    fig = plot2DChart(df, umap_embeds)

    textcontainer.write(text_news)
    plotcontainer.write(fig)
|
| 51 |
+
|
hausa.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c5279476f52eded50fa5254c9a6be01abe1393484eb57a8858f90c6d079e520e
|
| 3 |
+
size 14590027
|
igbo.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
news.ann
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:71443c486fb4dc39f3a600b705642795ad19c8ebca8e495259790e5351610b74
|
| 3 |
+
size 1603680
|
process.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#from dotenv import load_dotenv
|
| 2 |
+
from annoy import AnnoyIndex
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import numpy as np
|
| 5 |
+
import cohere
|
| 6 |
+
import os
|
| 7 |
+
import plotly.express as px
|
| 8 |
+
import umap
|
| 9 |
+
import plotly.graph_objects as go
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def get_key():
    """Return the Cohere API key from the environment.

    Reads COHERE_API_KEY (populate it via a .env file loaded with
    python-dotenv, or export it in the shell).

    Returns:
        str: the API key.

    Raises:
        RuntimeError: if COHERE_API_KEY is not set.
    """
    # SECURITY FIX: the key used to be hard-coded (and committed) here.
    # Never embed secrets in source; the leaked key must be revoked.
    #load_dotenv()
    key = os.environ.get("COHERE_API_KEY")
    if key is None:
        raise RuntimeError("COHERE_API_KEY is not set in the environment")
    return key
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def import_ds(frac=0.5, n_samples=500):
    """Load the per-language news CSVs and return a cleaned random sample.

    Args:
        frac: fraction of each language's dataset to keep before merging
            (default 0.5, the previously hard-coded value).
        n_samples: number of rows in the final combined sample
            (default 500, the previously hard-coded value).

    Returns:
        pandas.DataFrame: shuffled sample with missing and duplicate
        titles removed.
    """
    newsfiles = ['amharic', 'hausa', 'swahili', 'yoruba', 'igbo']

    # Sub-sample each language so the merged frame stays small enough
    # to embed within API limits.
    frames = [pd.read_csv(f'{name}.csv').sample(frac=frac) for name in newsfiles]

    df_news = pd.concat(frames, axis=0)

    # Shuffle so the final sample mixes languages.
    df_news = df_news.sample(frac=1)

    # Drop rows without a title, then duplicate stories.
    df_news = df_news[df_news['title'].notna()]
    df_news = df_news.drop_duplicates("title")

    return df_news.sample(n_samples)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def getEmbeddings(co, df):
    """Embed every news item's combined title+summary text via Cohere.

    Adds a 'text' column to *df* (title concatenated with summary) and
    returns the embeddings as a 2-D numpy array, one row per news item.
    """
    # The text sent to the model is the title followed by the summary.
    df['text'] = df['title'] + df['summary']
    trimmed = df.drop(['title', 'id', 'summary'], axis=1)

    response = co.embed(
        texts=list(trimmed['text']),
        model="multilingual-22-12",
        truncate="RIGHT",
    )
    return np.array(response.embeddings)
|
| 58 |
+
|
| 59 |
+
def semantic_search(emb, indexfile, n_trees=10):
    """Build an Annoy index over *emb* and persist it to *indexfile*.

    Args:
        emb: 2-D array-like of embedding vectors, one item per row.
        indexfile: path the built index is saved to (e.g. 'news.ann').
        n_trees: number of trees for the Annoy forest — more trees give
            higher accuracy at higher build cost (default 10, the
            previously hard-coded value).
    """
    emb = np.array(emb)

    # Angular distance matches cosine-style similarity of the embeddings.
    search_index = AnnoyIndex(emb.shape[1], 'angular')
    # (Removed a leftover debug print of the embedding dimension.)

    for i in range(len(emb)):
        search_index.add_item(i, emb[i])

    search_index.build(n_trees)
    search_index.save(indexfile)
|
| 71 |
+
|
| 72 |
+
def get_query_embed(co, query):
    """Embed a single query string and return it as a (1, dim) numpy array."""
    response = co.embed(
        texts=[query],
        model='multilingual-22-12',
        truncate='right',
    )
    return np.array(response.embeddings)
|
| 78 |
+
|
| 79 |
+
def getClosestNeighbours(indexfile, query_embed, neighbours=15, dim=768):
    """Load the saved Annoy index and return the nearest items to a query.

    Args:
        indexfile: path of the index saved by semantic_search().
        query_embed: (1, dim) array; only row 0 is used as the query vector.
        neighbours: number of nearest items to return.
        dim: dimensionality of the stored vectors — must match the index
            that was built (default 768, the previously hard-coded value;
            multilingual-22-12 embeddings are 768-dimensional).

    Returns:
        ([ids], [distances]) pair, because include_distances=True.
    """
    search_index = AnnoyIndex(dim, 'angular')
    search_index.load(indexfile)

    # Retrieve the nearest neighbors
    similar_item_ids = search_index.get_nns_by_vector(query_embed[0], neighbours,
                                                      include_distances=True)

    return similar_item_ids
|
| 90 |
+
|
| 91 |
+
def display_news(df, similar_item_ids):
    """Return a title/url/summary DataFrame for the matched news items.

    *similar_item_ids* is the (ids, distances) pair produced by the Annoy
    lookup; only the positional ids are used to select rows from *df*.
    """
    matches = df.iloc[similar_item_ids[0]]
    selected = {name: matches[name] for name in ('title', 'url', 'summary')}
    return pd.DataFrame(data=selected).reset_index(drop=True)
|
| 102 |
+
|
| 103 |
+
def getUMAPEmbed(embeds):
    """Project the high-dimensional embeddings down to 2-D with UMAP."""
    # n_neighbors trades off local vs. global structure in the projection.
    return umap.UMAP(n_neighbors=20).fit_transform(embeds)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def plot2DChart(df, umap_embeds, clusters=None):
    """Build a 2-D Plotly scatter of the UMAP projection.

    *clusters* is accepted for interface compatibility; the current plot
    does not use it beyond defaulting it to an empty dict.
    """
    clusters = {} if clusters is None else clusters

    plot_data = pd.DataFrame(data={'url': df['url'], 'title': df['title']})
    plot_data['x'] = umap_embeds[:, 0]
    plot_data['y'] = umap_embeds[:, 1]

    figure = px.scatter(plot_data, x='x', y='y', hover_data=['title'])

    # Reverse the trace order so later traces draw underneath earlier ones.
    figure.data = figure.data[::-1]

    return figure
|
| 126 |
+
|
| 127 |
+
if __name__ == '__main__':
    # Smoke-test the pipeline: embed a news sample, build the index,
    # then query it with the first article's own embedding.
    key = get_key()
    co = cohere.Client(key)
    df_news = import_ds()
    # BUG FIX: the previous version called undefined process(), called
    # semantic_search() without its indexfile argument, and passed a
    # DataFrame to getClosestNeighbours() instead of (indexfile, query).
    embed = getEmbeddings(co, df_news)
    indexfile = 'news.ann'
    semantic_search(embed, indexfile)
    print(getClosestNeighbours(indexfile, embed))
|
requirements.txt
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
altair==4.2.2
|
| 2 |
+
annoy==1.17.0
|
| 3 |
+
huggingface-hub==0.14.1
|
| 4 |
+
numpy==1.21.6
|
| 5 |
+
pandas==1.3.5
|
| 6 |
+
plotly==5.14.1
|
| 7 |
+
scipy==1.7.3
|
| 8 |
+
beautifulsoup4==4.11.1
|
| 9 |
+
cohere==2.7.0
|
| 10 |
+
matplotlib==3.5.1
|
| 11 |
+
python-dotenv==0.21.0
|
| 12 |
+
scikit_learn==1.0.2
|
| 13 |
+
streamlit==1.22.0
|
| 14 |
+
streamlit_plotly_events==0.0.6
|
| 15 |
+
umap==0.1.1
|
| 16 |
+
umap_learn==0.5.3
|
swahili.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bccf0a9aaa7f5399fa51b6d34df9848f5a077a771ce2318f7f6beb58686dee99
|
| 3 |
+
size 20901981
|
utils.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datasets import load_dataset_builder, load_dataset
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
def inspect():
    """Fetch XL-Sum dataset metadata (description and features).

    NOTE(review): desc/feat are overwritten on every loop iteration, so only
    the metadata of the final language is returned — presumably the intent
    was to collect metadata per language; confirm before relying on this.
    """
    langs = ['amharic','english','hausa','swahili','yoruba','igbo']

    for lang in langs:
        # The builder exposes dataset info without downloading the data.
        ds_builder = load_dataset_builder("csebuetnlp/xlsum",lang)

        desc = ds_builder.info.description

        feat = ds_builder.info.features

    return desc,feat
|
| 15 |
+
|
| 16 |
+
def load():
    """Download the XL-Sum train split for each language and dump it to CSV.

    Each language is handled independently so one failing download does not
    abort the rest (previously a single try wrapped the whole loop, and the
    error was hidden at DEBUG level); failures are logged, not raised.
    """
    langs = ['amharic', 'hausa', 'swahili', 'yoruba', 'igbo']

    for lang in langs:
        try:
            dataset = load_dataset("csebuetnlp/xlsum", lang, split="train")
            dataset.to_csv(f"{lang}.csv", index=None)
        except Exception as ex:
            # Log loudly enough to notice, but keep the best-effort behavior.
            logging.error("failed to export %s: %s", lang, ex)
|
| 29 |
+
|
| 30 |
+
if __name__ == '__main__':
    # Export every language's train split to CSV when run as a script.
    load()
|
yoruba.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9f5df52e87acfcd2fae999e7108a4f8c5e44345070b3c41380d72c47f8fd1412
|
| 3 |
+
size 16448886
|