healthtechbrasil committed
Commit 9eab4dd · Parent(s): 6952db2

force cpu usage
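For context, a minimal sketch (not from this repo) of the two usual ways to pin transformers weights to the CPU. This commit takes the `device_map` route, which relies on the accelerate package being installed:

```python
# Sketch of forcing CPU placement; the model ID matches the one used in app.py.
from transformers import T5ForConditionalGeneration

# Option 1: ask from_pretrained to place all weights on the CPU (needs accelerate)
model = T5ForConditionalGeneration.from_pretrained(
    "unicamp-dl/ptt5-base-portuguese-vocab",
    device_map="cpu",
)

# Option 2: load normally, then move the module with plain PyTorch
model = T5ForConditionalGeneration.from_pretrained(
    "unicamp-dl/ptt5-base-portuguese-vocab",
).to("cpu")
```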

Files changed (1)
  1. app.py +29 -20
app.py CHANGED
```diff
@@ -2,6 +2,11 @@ from fastapi import FastAPI
 from transformers import AutoTokenizer, T5ForConditionalGeneration
 import json
 import os
+import logging
+
+# Set up logging to capture more detail
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
 
 app = FastAPI()
 
@@ -9,30 +14,35 @@ app = FastAPI()
 try:
     with open("questions.json", "r", encoding="utf-8") as f:
         examples = json.load(f)
+        logger.info("questions.json carregado com sucesso.")
 except FileNotFoundError:
     examples = []
+    logger.warning("questions.json não encontrado, usando lista vazia.")
 
-# Function to load the model and tokenizer
-def load_model():
-    try:
-        tokenizer = AutoTokenizer.from_pretrained(
-            "unicamp-dl/ptt5-base-portuguese-vocab",
-            legacy=False,  # use the new tokenizer behavior
-            clean_up_tokenization_spaces=True  # set explicitly to avoid future warnings
-        )
-        model = T5ForConditionalGeneration.from_pretrained(
-            "unicamp-dl/ptt5-base-portuguese-vocab",
-            device_map="auto" if os.getenv("HF_TOKEN") else None
-        )
-        return {"tokenizer": tokenizer, "model": model}
-    except Exception as e:
-        print(f"Erro ao carregar o modelo: {e}")
-        return None
-
-# Initialize the model and tokenizer
-model_data = load_model()
+# Function to load the model and tokenizer on demand
+def get_model():
+    if not hasattr(get_model, "model_data"):
+        logger.info("Carregando modelo e tokenizer pela primeira vez...")
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(
+                "unicamp-dl/ptt5-base-portuguese-vocab",
+                legacy=False,
+                clean_up_tokenization_spaces=True
+            )
+            logger.info("Tokenizer carregado com sucesso.")
+            model = T5ForConditionalGeneration.from_pretrained(
+                "unicamp-dl/ptt5-base-portuguese-vocab",
+                device_map="cpu"  # force CPU usage
+            )
+            logger.info("Modelo carregado com sucesso.")
+            get_model.model_data = {"tokenizer": tokenizer, "model": model}
+        except Exception as e:
+            logger.error(f"Erro ao carregar modelo ou tokenizer: {e}")
+            get_model.model_data = None
+    return get_model.model_data
 
 def generate_question_from_prompt(theme, difficulty, example_question=None):
+    model_data = get_model()
     if not model_data or not model_data["tokenizer"] or not model_data["model"]:
         return {"question": "Erro: Modelo ou tokenizer não carregado.", "options": [], "answer": "", "explanation": "Por favor, verifique os logs."}
 
@@ -65,7 +75,6 @@ def generate_question_from_prompt(theme, difficulty, example_question=None):
     inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
     outputs = model.generate(**inputs, max_new_tokens=512, temperature=0.7, top_p=0.9)
     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    # Parse the response to extract its components
     parts = response.split("Alternativas:")
     if len(parts) > 1:
         question_part = parts[0].replace("Enunciado clínico:", "").strip()
```
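Not part of the commit: a short usage sketch of the new lazy loader, assuming it runs next to app.py with the model download available. It exercises the function-attribute cache (the second call returns the same objects) and checks that `device_map="cpu"` actually placed the parameters on the CPU:

```python
# Hypothetical check script, not from the repo; relies only on app.py above.
from app import get_model

data = get_model()            # first call: loads model + tokenizer, then caches them
if data is None:
    raise SystemExit("Model failed to load; check the logs.")

assert data is get_model()    # second call returns the cached dict, no reload

# With device_map="cpu", every parameter should live on the CPU.
print(next(data["model"].parameters()).device)  # expected: cpu
```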