Spaces:
Runtime error
Runtime error
Upload tool
Browse files- app.py +6 -0
- requirements.txt +1 -0
- tool.py +43 -0
app.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from smolagents import launch_gradio_demo
|
2 |
+
from tool import TranslationTool
|
3 |
+
|
4 |
+
tool = TranslationTool()
|
5 |
+
|
6 |
+
launch_gradio_demo(tool)
|
requirements.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
smolagents
|
tool.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Optional
|
2 |
+
from smolagents.tools import Tool
|
3 |
+
|
4 |
+
class TranslationTool(Tool):
|
5 |
+
"""
|
6 |
+
Example:
|
7 |
+
|
8 |
+
```py
|
9 |
+
translator = TranslationTool()
|
10 |
+
translator("This is a super nice API!", src_lang="English", tgt_lang="French")
|
11 |
+
```
|
12 |
+
"""
|
13 |
+
default_checkpoint = "facebook/nllb-200-distilled-600M"
|
14 |
+
description = "This is a tool that translates text from a language to another."
|
15 |
+
name = "translator"
|
16 |
+
inputs = {'text': {'type': 'string', 'description': 'The text to translate'}, 'src_lang': {'type': 'string', 'description': "The language of the text to translate. Written in plain English, such as 'Romanian', or 'Albanian'"}, 'tgt_lang': {'type': 'string', 'description': "The language for the desired output language. Written in plain English, such as 'Romanian', or 'Albanian'"}}
|
17 |
+
output_type = "string"
|
18 |
+
|
19 |
+
def __init__(self, lang_to_code=LANGUAGE_CODES, pre_processor_class=AutoTokenizer, model_class=AutoModelForSeq2SeqLM):
|
20 |
+
super().__init__()
|
21 |
+
self.lang_to_code = lang_to_code
|
22 |
+
self.pre_processor_class = pre_processor_class
|
23 |
+
self.model_class = model_class
|
24 |
+
# self.pre_processor = self.pre_processor_class.from_pretrained(self.default_checkpoint)
|
25 |
+
# self.model = self.model_class.from_pretrained(self.default_checkpoint)
|
26 |
+
# self.post_processor = self.pre_processor
|
27 |
+
|
28 |
+
def encode(self, text, src_lang, tgt_lang):
|
29 |
+
if src_lang not in self.lang_to_code:
|
30 |
+
raise ValueError(f"{src_lang} is not a supported language.")
|
31 |
+
if tgt_lang not in self.lang_to_code:
|
32 |
+
raise ValueError(f"{tgt_lang} is not a supported language.")
|
33 |
+
src_lang = self.lang_to_code[src_lang]
|
34 |
+
tgt_lang = self.lang_to_code[tgt_lang]
|
35 |
+
return self.pre_processor._build_translation_inputs(
|
36 |
+
text, return_tensors="pt", src_lang=src_lang, tgt_lang=tgt_lang
|
37 |
+
)
|
38 |
+
|
39 |
+
def forward(self, inputs):
|
40 |
+
return self.model.generate(**inputs)
|
41 |
+
|
42 |
+
def decode(self, outputs):
|
43 |
+
return self.post_processor.decode(outputs[0].tolist(), skip_special_tokens=True)
|