FlameF0X committed on
Commit
3c887d1
·
verified ·
1 Parent(s): 7ca4acf

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -0
app.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+
5
# Hugging Face repo that hosts both the weights and the tokenizer.
MODEL_ID = "FlameF0X/SnowflakeCore-G1-Tiny"

# Load model and tokenizer once at process startup.
# - trust_remote_code=True: the repo ships custom modeling code.
# - use_safetensors=True: load safetensors weights instead of pickle files.
# NOTE(review): the original also passed force_download=True, which forced a
# full re-download of the weights on EVERY launch. The local HF cache already
# revalidates files against the hub, so the flag is dropped here.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    use_safetensors=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    use_safetensors=True,
)
18
+
19
def custom_greedy_generate(prompt, max_length=50):
    """Greedy-decode a continuation of ``prompt``.

    Feeds the growing token sequence back through the model, appending the
    highest-logit token each step, until ``max_length`` new tokens have been
    produced or the tokenizer's EOS token appears. Returns the decoded text,
    prompt included, with special tokens stripped.
    """
    model.eval()
    tokens = tokenizer(prompt, return_tensors="pt").input_ids
    with torch.no_grad():
        for _ in range(max_length):
            logits = model(input_ids=tokens)["logits"]
            # Greedy choice: argmax over the vocabulary at the last position.
            next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            tokens = torch.cat((tokens, next_id), dim=1)
            if next_id.item() == tokenizer.eos_token_id:
                break
    return tokenizer.decode(tokens[0], skip_special_tokens=True)
32
+
33
def gradio_generate(prompt):
    """Gradio callback: adapt the textbox input to the greedy generator."""
    generated_text = custom_greedy_generate(prompt)
    return generated_text
35
+
36
# Wire up the web UI: one prompt textbox in, one generated-text box out.
_prompt_box = gr.Textbox(lines=2, placeholder="Enter your prompt here...")
_output_box = gr.Textbox(label="Generated Text")
iface = gr.Interface(
    fn=gradio_generate,
    inputs=_prompt_box,
    outputs=_output_box,
    title="SnowflakeCore-G1-Tiny Text Generation",
    description="Enter a prompt and generate text using the SnowflakeCore-G1-Tiny model.",
)

# Start the Gradio server only when run as a script (not on import).
if __name__ == "__main__":
    iface.launch()