Spaces:
Runtime error
Runtime error
Update logger
Browse files
model.py
CHANGED
|
@@ -80,10 +80,10 @@ formatter = logging.Formatter(
|
|
| 80 |
'[%(asctime)s] %(name)s %(levelname)s: %(message)s',
|
| 81 |
datefmt='%Y-%m-%d %H:%M:%S')
|
| 82 |
stream_handler = logging.StreamHandler(stream=sys.stdout)
|
| 83 |
-
stream_handler.setLevel(logging.
|
| 84 |
stream_handler.setFormatter(formatter)
|
| 85 |
logger = logging.getLogger(__name__)
|
| 86 |
-
logger.setLevel(logging.
|
| 87 |
logger.propagate = False
|
| 88 |
logger.addHandler(stream_handler)
|
| 89 |
|
|
@@ -254,7 +254,7 @@ class Model:
|
|
| 254 |
self.style = style
|
| 255 |
self.args = argparse.Namespace(**(vars(self.args) | get_recipe(style)))
|
| 256 |
self.query_template = self.args.query_template
|
| 257 |
-
logger.
|
| 258 |
|
| 259 |
self.strategy.temperature = self.args.temp_all_gen
|
| 260 |
|
|
@@ -296,7 +296,7 @@ class Model:
|
|
| 296 |
start = time.perf_counter()
|
| 297 |
|
| 298 |
text = self.query_template.format(text)
|
| 299 |
-
logger.
|
| 300 |
seq = tokenizer.encode(text)
|
| 301 |
logger.info(f'{len(seq)=}')
|
| 302 |
if len(seq) > 110:
|
|
@@ -342,7 +342,7 @@ class Model:
|
|
| 342 |
output_list.append(coarse_samples)
|
| 343 |
remaining -= self.max_batch_size
|
| 344 |
output_tokens = torch.cat(output_list, dim=0)
|
| 345 |
-
logger.
|
| 346 |
|
| 347 |
elapsed = time.perf_counter() - start
|
| 348 |
logger.info(f'Elapsed: {elapsed}')
|
|
@@ -360,7 +360,7 @@ class Model:
|
|
| 360 |
logger.info('--- generate_images ---')
|
| 361 |
start = time.perf_counter()
|
| 362 |
|
| 363 |
-
logger.
|
| 364 |
res = []
|
| 365 |
if self.only_first_stage:
|
| 366 |
for i in range(len(tokens)):
|
|
@@ -414,6 +414,9 @@ class AppModel(Model):
|
|
| 414 |
self, text: str, translate: bool, style: str, seed: int,
|
| 415 |
only_first_stage: bool, num: int
|
| 416 |
) -> tuple[str | None, np.ndarray | None, list[np.ndarray] | None]:
|
|
|
|
|
|
|
|
|
|
| 417 |
if translate:
|
| 418 |
text = translated_text = self.translator(text)
|
| 419 |
else:
|
|
|
|
| 80 |
'[%(asctime)s] %(name)s %(levelname)s: %(message)s',
|
| 81 |
datefmt='%Y-%m-%d %H:%M:%S')
|
| 82 |
stream_handler = logging.StreamHandler(stream=sys.stdout)
|
| 83 |
+
stream_handler.setLevel(logging.INFO)
|
| 84 |
stream_handler.setFormatter(formatter)
|
| 85 |
logger = logging.getLogger(__name__)
|
| 86 |
+
logger.setLevel(logging.INFO)
|
| 87 |
logger.propagate = False
|
| 88 |
logger.addHandler(stream_handler)
|
| 89 |
|
|
|
|
| 254 |
self.style = style
|
| 255 |
self.args = argparse.Namespace(**(vars(self.args) | get_recipe(style)))
|
| 256 |
self.query_template = self.args.query_template
|
| 257 |
+
logger.debug(f'{self.query_template=}')
|
| 258 |
|
| 259 |
self.strategy.temperature = self.args.temp_all_gen
|
| 260 |
|
|
|
|
| 296 |
start = time.perf_counter()
|
| 297 |
|
| 298 |
text = self.query_template.format(text)
|
| 299 |
+
logger.debug(f'{text=}')
|
| 300 |
seq = tokenizer.encode(text)
|
| 301 |
logger.info(f'{len(seq)=}')
|
| 302 |
if len(seq) > 110:
|
|
|
|
| 342 |
output_list.append(coarse_samples)
|
| 343 |
remaining -= self.max_batch_size
|
| 344 |
output_tokens = torch.cat(output_list, dim=0)
|
| 345 |
+
logger.debug(f'{output_tokens.shape=}')
|
| 346 |
|
| 347 |
elapsed = time.perf_counter() - start
|
| 348 |
logger.info(f'Elapsed: {elapsed}')
|
|
|
|
| 360 |
logger.info('--- generate_images ---')
|
| 361 |
start = time.perf_counter()
|
| 362 |
|
| 363 |
+
logger.debug(f'{self.only_first_stage=}')
|
| 364 |
res = []
|
| 365 |
if self.only_first_stage:
|
| 366 |
for i in range(len(tokens)):
|
|
|
|
| 414 |
self, text: str, translate: bool, style: str, seed: int,
|
| 415 |
only_first_stage: bool, num: int
|
| 416 |
) -> tuple[str | None, np.ndarray | None, list[np.ndarray] | None]:
|
| 417 |
+
logger.info(
|
| 418 |
+
f'{text=}, {translate=}, {style=}, {seed=}, {only_first_stage=}, {num=}'
|
| 419 |
+
)
|
| 420 |
if translate:
|
| 421 |
text = translated_text = self.translator(text)
|
| 422 |
else:
|