elucidator8918 committed on
Commit
352251d
·
verified ·
1 Parent(s): 998e8ac

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +26 -8
tasks/text.py CHANGED
@@ -9,7 +9,7 @@ from .utils.emissions import tracker, clean_emissions_data, get_space_info
9
 
10
  router = APIRouter()
11
 
12
- DESCRIPTION = "Random Baseline"
13
  ROUTE = "/text"
14
 
15
  @router.post(ROUTE, tags=["Text Task"],
@@ -18,9 +18,7 @@ async def evaluate_text(request: TextEvaluationRequest):
18
  """
19
  Evaluate text classification for climate disinformation detection.
20
 
21
- Current Model: Random Baseline
22
- - Makes random predictions from the label space (0-7)
23
- - Used as a baseline for comparison
24
  """
25
  # Get space info
26
  username, space_url = get_space_info()
@@ -46,6 +44,18 @@ async def evaluate_text(request: TextEvaluationRequest):
46
  # Split dataset
47
  train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
48
  test_dataset = train_test["test"]
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  # Start tracking emissions
51
  tracker.start()
@@ -56,15 +66,23 @@ async def evaluate_text(request: TextEvaluationRequest):
56
  # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
57
  #--------------------------------------------------------------------------------------------
58
 
59
- # Make random predictions (placeholder for actual model inference)
60
- true_labels = test_dataset["label"]
61
- predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
 
 
 
 
 
 
 
 
 
62
 
63
  #--------------------------------------------------------------------------------------------
64
  # YOUR MODEL INFERENCE STOPS HERE
65
  #--------------------------------------------------------------------------------------------
66
 
67
-
68
  # Stop tracking emissions
69
  emissions_data = tracker.stop_task()
70
 
 
9
 
10
  router = APIRouter()
11
 
12
+ DESCRIPTION = "GTE Architecture"
13
  ROUTE = "/text"
14
 
15
  @router.post(ROUTE, tags=["Text Task"],
 
18
  """
19
  Evaluate text classification for climate disinformation detection.
20
 
21
+ Current Model: GTE Architecture
 
 
22
  """
23
  # Get space info
24
  username, space_url = get_space_info()
 
44
  # Split dataset
45
  train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
46
  test_dataset = train_test["test"]
47
+
48
+ true_labels = test_dataset["label"]
49
+ texts = test_dataset["quote"]
50
+
51
+ model_repo = "elucidator8918/frugal-ai-text"
52
+ config = AutoConfig.from_pretrained(model_repo)
53
+ model = AutoModelForSequenceClassification.from_pretrained(model_repo)
54
+ tokenizer = AutoTokenizer.from_pretrained(model_repo)
55
+
56
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
57
+ model = model.to(device)
58
+ model.eval()
59
 
60
  # Start tracking emissions
61
  tracker.start()
 
66
  # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
67
  #--------------------------------------------------------------------------------------------
68
 
69
+ text_encoding = tokenizer(
70
+ texts,
71
+ truncation=True,
72
+ padding=True,
73
+ return_tensors="pt",
74
+ )
75
+
76
+ with torch.no_grad():
77
+ text_input_ids = text_encoding["input_ids"].to(device)
78
+ text_attention_mask = text_encoding["attention_mask"].to(device)
79
+ outputs = model(text_input_ids, text_attention_mask)
80
+ predictions = torch.argmax(outputs.logits, dim=1).cpu().numpy()
81
 
82
  #--------------------------------------------------------------------------------------------
83
  # YOUR MODEL INFERENCE STOPS HERE
84
  #--------------------------------------------------------------------------------------------
85
 
 
86
  # Stop tracking emissions
87
  emissions_data = tracker.stop_task()
88