FlameF0X commited on
Commit
18a753b
·
verified ·
1 Parent(s): fd28877

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -4
app.py CHANGED
@@ -2,13 +2,18 @@ import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
 
 
 
 
5
  # Load model and tokenizer
 
6
  model = AutoModelForCausalLM.from_pretrained(
7
  "FlameF0X/SnowflakeCore-G1-Tiny2",
8
  trust_remote_code=True,
9
  force_download=True,
10
  use_safetensors=True,
11
- )
12
  tokenizer = AutoTokenizer.from_pretrained(
13
  "FlameF0X/SnowflakeCore-G1-Tiny2",
14
  trust_remote_code=True,
@@ -17,11 +22,17 @@ tokenizer = AutoTokenizer.from_pretrained(
17
  )
18
 
19
  def custom_greedy_generate(prompt, max_length=50):
 
 
 
 
20
  model.eval()
21
- input_ids = tokenizer(prompt, return_tensors="pt").input_ids
 
22
  generated = input_ids
23
  with torch.no_grad():
24
  for _ in range(max_length):
 
25
  outputs = model(input_ids=generated)
26
  next_token_logits = outputs["logits"][:, -1, :]
27
  next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
@@ -31,15 +42,20 @@ def custom_greedy_generate(prompt, max_length=50):
31
  return tokenizer.decode(generated[0], skip_special_tokens=True)
32
 
33
  def gradio_generate(prompt):
 
 
 
34
  return custom_greedy_generate(prompt)
35
 
 
36
  iface = gr.Interface(
37
  fn=gradio_generate,
38
  inputs=gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
39
  outputs=gr.Textbox(label="Generated Text"),
40
  title="SnowflakeCore-G1-Tiny2 Text Generation",
41
- description="Enter a prompt and generate text using the SnowflakeCore-G1-Tiny2 model.",
42
  )
43
 
 
44
  if __name__ == "__main__":
45
- iface.launch()
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
5
+ # Determine the device to use (GPU if available, otherwise CPU)
6
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
+ print(f"Using device: {device}")
8
+
9
  # Load model and tokenizer
10
+ # Move the model to the determined device
11
  model = AutoModelForCausalLM.from_pretrained(
12
  "FlameF0X/SnowflakeCore-G1-Tiny2",
13
  trust_remote_code=True,
14
  force_download=True,
15
  use_safetensors=True,
16
+ ).to(device) # Move model to GPU or CPU
17
  tokenizer = AutoTokenizer.from_pretrained(
18
  "FlameF0X/SnowflakeCore-G1-Tiny2",
19
  trust_remote_code=True,
 
22
  )
23
 
24
  def custom_greedy_generate(prompt, max_length=50):
25
+ """
26
+ Generates text using a custom greedy decoding approach.
27
+ The model and input tensors are moved to the appropriate device (GPU/CPU).
28
+ """
29
  model.eval()
30
+ # Move input_ids to the same device as the model
31
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
32
  generated = input_ids
33
  with torch.no_grad():
34
  for _ in range(max_length):
35
+ # Ensure the generated tensor is on the correct device for model input
36
  outputs = model(input_ids=generated)
37
  next_token_logits = outputs["logits"][:, -1, :]
38
  next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
 
42
  return tokenizer.decode(generated[0], skip_special_tokens=True)
43
 
44
  def gradio_generate(prompt):
45
+ """
46
+ Wrapper function for Gradio interface.
47
+ """
48
  return custom_greedy_generate(prompt)
49
 
50
+ # Create the Gradio interface
51
  iface = gr.Interface(
52
  fn=gradio_generate,
53
  inputs=gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
54
  outputs=gr.Textbox(label="Generated Text"),
55
  title="SnowflakeCore-G1-Tiny2 Text Generation",
56
+ description=f"Enter a prompt and generate text using the SnowflakeCore-G1-Tiny2 model. Running on: {device}",
57
  )
58
 
59
+ # Launch the Gradio application
60
  if __name__ == "__main__":
61
+ iface.launch()