mramazan commited on
Commit
d9ab129
·
verified ·
1 Parent(s): 005ac7b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +566 -0
app.py ADDED
@@ -0,0 +1,566 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import sys
3
+ import pickle
4
+ import json
5
+ import gc
6
+ import torch
7
+ from pathlib import Path
8
+ import gdown
9
+ import os
10
+ import difflib
11
+ from datetime import datetime
12
+ import random
13
+
14
+ # Import your existing modules
15
+ from utils import *
16
+ from options import args
17
+ from models import model_factory
18
+
19
+ class LazyDict:
20
+ def __init__(self, file_path):
21
+ self.file_path = file_path
22
+ self._data = None
23
+ self._loaded = False
24
+
25
+ def _load_data(self):
26
+ if not self._loaded:
27
+ try:
28
+ with open(self.file_path, "r", encoding="utf-8") as file:
29
+ self._data = json.load(file)
30
+ self._loaded = True
31
+ except Exception as e:
32
+ print(f"Warning: Could not load {self.file_path}: {str(e)}")
33
+ self._data = {}
34
+ self._loaded = True
35
+
36
+ def get(self, key, default=None):
37
+ self._load_data()
38
+ return self._data.get(key, default)
39
+
40
+ def __contains__(self, key):
41
+ self._load_data()
42
+ return key in self._data
43
+
44
+ def items(self):
45
+ self._load_data()
46
+ return self._data.items()
47
+
48
+ def keys(self):
49
+ self._load_data()
50
+ return self._data.keys()
51
+
52
+ def __len__(self):
53
+ self._load_data()
54
+ return len(self._data)
55
+
56
+ class AnimeRecommendationSystem:
57
+ def __init__(self, checkpoint_path, dataset_path, animes_path, images_path, mal_urls_path, type_seq_path, genres_path):
58
+ self.model = None
59
+ self.dataset = None
60
+ self.checkpoint_path = checkpoint_path
61
+ self.dataset_path = dataset_path
62
+ self.animes_path = animes_path
63
+
64
+ # Lazy loading ile memory optimization
65
+ self.id_to_anime = LazyDict(animes_path)
66
+ self.id_to_url = LazyDict(images_path)
67
+ self.id_to_mal_url = LazyDict(mal_urls_path)
68
+ self.id_to_type_seq = LazyDict(type_seq_path)
69
+ self.id_to_genres = LazyDict(genres_path)
70
+
71
+ # Cache için weak reference kullan
72
+ self._cache = {}
73
+
74
+ self.load_model_and_data()
75
+
76
+ def load_model_and_data(self):
77
+ try:
78
+ print("Loading model and data...")
79
+ args.bert_max_len = 128
80
+
81
+ # Dataset'i yükle
82
+ dataset_path = Path(self.dataset_path)
83
+ with dataset_path.open('rb') as f:
84
+ self.dataset = pickle.load(f)["smap"]
85
+
86
+ args.num_items = len(self.dataset)
87
+
88
+ # Model'i yükle
89
+ self.model = model_factory(args)
90
+ self.load_checkpoint()
91
+
92
+ # Garbage collection
93
+ gc.collect()
94
+ print("Model loaded successfully!")
95
+
96
+ except Exception as e:
97
+ print(f"Error loading model: {str(e)}")
98
+ raise e
99
+
100
+ def load_checkpoint(self):
101
+ try:
102
+ with open(self.checkpoint_path, 'rb') as f:
103
+ checkpoint = torch.load(f, map_location='cpu', weights_only=False)
104
+ self.model.load_state_dict(checkpoint['model_state_dict'])
105
+ self.model.eval()
106
+
107
+ # Checkpoint'i bellekten temizle
108
+ del checkpoint
109
+ gc.collect()
110
+
111
+ except Exception as e:
112
+ raise Exception(f"Failed to load checkpoint from {self.checkpoint_path}: {str(e)}")
113
+
114
+ def get_anime_genres(self, anime_id):
115
+ genres = self.id_to_genres.get(str(anime_id), [])
116
+ return [genre.title() for genre in genres] if genres else []
117
+
118
+ def get_anime_image_url(self, anime_id):
119
+ return self.id_to_url.get(str(anime_id), None)
120
+
121
+ def get_anime_mal_url(self, anime_id):
122
+ return self.id_to_mal_url.get(str(anime_id), None)
123
+
124
+ def _is_hentai(self, anime_id):
125
+ type_seq_info = self.id_to_type_seq.get(str(anime_id))
126
+ if not type_seq_info or len(type_seq_info) < 3:
127
+ return False
128
+ return type_seq_info[2]
129
+
130
+ def _get_type(self, anime_id):
131
+ type_seq_info = self.id_to_type_seq.get(str(anime_id))
132
+ if not type_seq_info or len(type_seq_info) < 2:
133
+ return "Unknown"
134
+ return type_seq_info[0]
135
+
136
+ def find_closest_anime(self, input_name):
137
+ """Finds the closest matching anime to the input name"""
138
+ anime_names = {}
139
+
140
+ # Collect all titles (main + alternative)
141
+ for k, v in self.id_to_anime.items():
142
+ anime_id = int(k)
143
+ if isinstance(v, list) and len(v) > 0:
144
+ # Main title
145
+ main_title = v[0]
146
+ anime_names[main_title.lower().strip()] = (anime_id, main_title)
147
+ # Alternative titles
148
+ if len(v) > 1:
149
+ for alt_title in v[1:]:
150
+ if alt_title and isinstance(alt_title, str):
151
+ alt_title_clean = alt_title.strip()
152
+ if alt_title_clean:
153
+ anime_names[alt_title_clean.lower()] = (anime_id, main_title)
154
+ else:
155
+ title = str(v).strip()
156
+ anime_names[title.lower()] = (anime_id, title)
157
+
158
+ input_lower = input_name.lower().strip()
159
+
160
+ # 1. Exact match
161
+ if input_lower in anime_names:
162
+ return anime_names[input_lower]
163
+
164
+ # 2. Substring search
165
+ for anime_name_lower, (anime_id, main_title) in anime_names.items():
166
+ if input_lower in anime_name_lower:
167
+ return (anime_id, main_title)
168
+
169
+ # 3. Fuzzy matching
170
+ anime_name_list = list(anime_names.keys())
171
+ close_matches = difflib.get_close_matches(input_lower, anime_name_list, n=1, cutoff=0.6)
172
+
173
+ if close_matches:
174
+ match = close_matches[0]
175
+ return anime_names[match]
176
+
177
+ return None
178
+
179
+ def search_animes(self, query):
180
+ """Search animes by query"""
181
+ animes = []
182
+ query_lower = query.lower() if query else ""
183
+
184
+ count = 0
185
+ for k, v in self.id_to_anime.items():
186
+ if count >= 200: # Limit for performance
187
+ break
188
+
189
+ anime_names = v if isinstance(v, list) else [v]
190
+ match_found = False
191
+
192
+ for name in anime_names:
193
+ if not query or query_lower in name.lower():
194
+ match_found = True
195
+ break
196
+
197
+ if match_found:
198
+ main_name = anime_names[0] if anime_names else "Unknown"
199
+ animes.append((int(k), main_name))
200
+ count += 1
201
+
202
+ animes.sort(key=lambda x: x[1])
203
+ return animes
204
+
205
+ def get_recommendations(self, favorite_anime_ids, num_recommendations=20, filters=None):
206
+ try:
207
+ if not favorite_anime_ids:
208
+ return [], [], "Please add some favorite animes first!"
209
+
210
+ smap = self.dataset
211
+ inverted_smap = {v: k for k, v in smap.items()}
212
+
213
+ converted_ids = []
214
+ for anime_id in favorite_anime_ids:
215
+ if anime_id in smap:
216
+ converted_ids.append(smap[anime_id])
217
+
218
+ if not converted_ids:
219
+ return [], [], "None of the selected animes are in the model vocabulary!"
220
+
221
+ # Normal recommendations
222
+ target_len = 128
223
+ padded = converted_ids + [0] * (target_len - len(converted_ids))
224
+ input_tensor = torch.tensor(padded, dtype=torch.long).unsqueeze(0)
225
+
226
+ max_predictions = min(75, len(inverted_smap))
227
+
228
+ with torch.no_grad():
229
+ logits = self.model(input_tensor)
230
+ last_logits = logits[:, -1, :]
231
+ top_scores, top_indices = torch.topk(last_logits, k=max_predictions, dim=1)
232
+
233
+ recommendations = []
234
+ scores = []
235
+
236
+ for idx, score in zip(top_indices.numpy()[0], top_scores.detach().numpy()[0]):
237
+ if idx in inverted_smap:
238
+ anime_id = inverted_smap[idx]
239
+
240
+ if anime_id in favorite_anime_ids:
241
+ continue
242
+
243
+ if str(anime_id) in self.id_to_anime:
244
+ # Filter check
245
+ if filters and not self._should_include_anime(anime_id, filters):
246
+ continue
247
+
248
+ anime_data = self.id_to_anime.get(str(anime_id))
249
+ anime_name = anime_data[0] if isinstance(anime_data, list) and len(anime_data) > 0 else str(anime_data)
250
+
251
+ image_url = self.get_anime_image_url(anime_id)
252
+ mal_url = self.get_anime_mal_url(anime_id)
253
+
254
+ recommendations.append({
255
+ 'id': anime_id,
256
+ 'name': anime_name,
257
+ 'score': float(score),
258
+ 'image_url': image_url,
259
+ 'mal_url': mal_url,
260
+ 'genres': self.get_anime_genres(anime_id),
261
+ 'type': self._get_type(anime_id)
262
+ })
263
+ scores.append(float(score))
264
+
265
+ if len(recommendations) >= num_recommendations:
266
+ break
267
+
268
+ # Memory cleanup
269
+ del logits, last_logits, top_scores, top_indices
270
+ gc.collect()
271
+
272
+ return recommendations, scores, f"Found {len(recommendations)} recommendations!"
273
+
274
+ except Exception as e:
275
+ return [], [], f"Error during prediction: {str(e)}"
276
+
277
+ def _should_include_anime(self, anime_id, filters):
278
+ """Check if anime should be included based on filters"""
279
+ if not filters:
280
+ return True
281
+
282
+ type_seq_info = self.id_to_type_seq.get(str(anime_id))
283
+ if not type_seq_info or len(type_seq_info) < 2:
284
+ return True
285
+
286
+ anime_type = type_seq_info[0]
287
+ is_sequel = type_seq_info[1] if len(type_seq_info) > 1 else False
288
+ is_hentai = type_seq_info[2] if len(type_seq_info) > 2 else False
289
+
290
+ # Hentai filter
291
+ if not filters.get('show_hentai', True) and is_hentai:
292
+ return False
293
+
294
+ # Sequel filter
295
+ if not filters.get('show_sequels', True) and is_sequel:
296
+ return False
297
+
298
+ # Type filters
299
+ if not filters.get('show_movies', True) and anime_type == 'MOVIE':
300
+ return False
301
+ if not filters.get('show_tv', True) and anime_type == 'TV':
302
+ return False
303
+ if not filters.get('show_ova', True) and anime_type in ['ONA', 'OVA', 'SPECIAL']:
304
+ return False
305
+
306
+ return True
307
+
308
+ # Global recommendation system
309
+ recommendation_system = None
310
+
311
+ def initialize_system():
312
+ global recommendation_system
313
+ if recommendation_system is None:
314
+ try:
315
+ args.num_items = 12689
316
+
317
+ file_ids = {
318
+ "1C6mdjblhiWGhRgbIk5DP2XCc4ElS9x8p": "pretrained_bert.pth",
319
+ "1U42cFrdLFT8NVNikT9C5SD9aAux7a5U2": "animes.json",
320
+ "1s-8FM1Wi2wOWJ9cstvm-O1_6XculTcTG": "dataset.pkl",
321
+ "1SOm1llcTKfhr-RTHC0dhaZ4AfWPs8wRx": "id_to_url.json",
322
+ "1vwJEMEOIYwvCKCCbbeaP0U_9L3NhvBzg": "anime_to_malurl.json",
323
+ "1_TyzON6ie2CqvzVNvPyc9prMTwLMefdu": "anime_to_typenseq.json",
324
+ "1G9O_ahyuJ5aO0cwoVnIXrlzMqjKrf2aw": "id_to_genres.json"
325
+ }
326
+
327
+ def download_from_gdrive(file_id, output_path):
328
+ url = f"https://drive.google.com/uc?id={file_id}"
329
+ try:
330
+ print(f"Downloading: {output_path}")
331
+ gdown.download(url, output_path, quiet=False)
332
+ print(f"Downloaded: {output_path}")
333
+ return True
334
+ except Exception as e:
335
+ print(f"Error downloading {output_path}: {e}")
336
+ return False
337
+
338
+ for file_id, filename in file_ids.items():
339
+ if not os.path.isfile(filename):
340
+ download_from_gdrive(file_id, filename)
341
+
342
+ recommendation_system = AnimeRecommendationSystem(
343
+ "pretrained_bert.pth",
344
+ "dataset.pkl",
345
+ "animes.json",
346
+ "id_to_url.json",
347
+ "anime_to_malurl.json",
348
+ "anime_to_typenseq.json",
349
+ "id_to_genres.json"
350
+ )
351
+ print("Recommendation system initialized successfully!")
352
+
353
+ except Exception as e:
354
+ print(f"Failed to initialize recommendation system: {e}")
355
+ return f"Error: {str(e)}"
356
+
357
+ return "System ready!"
358
+
359
+ def search_and_add_anime(query, favorites_state):
360
+ """Search anime and return search results"""
361
+ if not recommendation_system:
362
+ return "System not initialized", favorites_state, ""
363
+
364
+ if not query.strip():
365
+ return "Please enter an anime name to search", favorites_state, ""
366
+
367
+ # Search for anime
368
+ result = recommendation_system.find_closest_anime(query.strip())
369
+
370
+ if result:
371
+ anime_id, anime_name = result
372
+
373
+ # Check if already in favorites
374
+ if anime_id in favorites_state:
375
+ return f"'{anime_name}' is already in your favorites", favorites_state, ""
376
+
377
+ # Add to favorites
378
+ if len(favorites_state) >= 15:
379
+ return "Maximum 15 favorite animes allowed", favorites_state, ""
380
+
381
+ favorites_state.append(anime_id)
382
+ return f"Added '{anime_name}' to favorites", favorites_state, ""
383
+ else:
384
+ return f"No anime found matching '{query}'", favorites_state, ""
385
+
386
+ def get_favorites_display(favorites_state):
387
+ """Get display string for favorites"""
388
+ if not favorites_state or not recommendation_system:
389
+ return "No favorites added yet"
390
+
391
+ display = "Your Favorite Animes:\n"
392
+ for i, anime_id in enumerate(favorites_state, 1):
393
+ anime_data = recommendation_system.id_to_anime.get(str(anime_id))
394
+ if anime_data:
395
+ anime_name = anime_data[0] if isinstance(anime_data, list) else str(anime_data)
396
+ display += f"{i}. {anime_name}\n"
397
+
398
+ return display
399
+
400
+ def clear_favorites(favorites_state):
401
+ """Clear all favorites"""
402
+ return "Favorites cleared", [], ""
403
+
404
+ def get_recommendations_gradio(favorites_state, num_recs, show_hentai, show_sequels, show_movies, show_tv, show_ova):
405
+ """Get recommendations for Gradio interface"""
406
+ if not recommendation_system:
407
+ return "System not initialized"
408
+
409
+ if not favorites_state:
410
+ return "Please add some favorite animes first!"
411
+
412
+ # Prepare filters
413
+ filters = {
414
+ 'show_hentai': show_hentai,
415
+ 'show_sequels': show_sequels,
416
+ 'show_movies': show_movies,
417
+ 'show_tv': show_tv,
418
+ 'show_ova': show_ova
419
+ }
420
+
421
+ recommendations, scores, message = recommendation_system.get_recommendations(
422
+ favorites_state,
423
+ num_recommendations=int(num_recs),
424
+ filters=filters
425
+ )
426
+
427
+ if not recommendations:
428
+ return f"No recommendations found. {message}"
429
+
430
+ # Format recommendations
431
+ result = f"**{message}**\n\n"
432
+
433
+ for i, rec in enumerate(recommendations, 1):
434
+ result += f"**{i}. {rec['name']}**\n"
435
+ result += f"Score: {rec['score']:.4f}\n"
436
+ result += f"Type: {rec.get('type', 'Unknown')}\n"
437
+
438
+ if rec['genres']:
439
+ result += f"Genres: {', '.join(rec['genres'])}\n"
440
+
441
+ if rec.get('mal_url'):
442
+ result += f"[MyAnimeList Link]({rec['mal_url']})\n"
443
+
444
+ result += "\n" + "-"*50 + "\n\n"
445
+
446
+ return result
447
+
448
+ def create_interface():
449
+ # Initialize system
450
+ init_status = initialize_system()
451
+ print(init_status)
452
+
453
+ with gr.Blocks(title="Anime Recommendation System", theme=gr.themes.Soft()) as demo:
454
+ # State for favorites
455
+ favorites_state = gr.State([])
456
+
457
+ gr.HTML("""
458
+ <div style="text-align: center; margin-bottom: 20px;">
459
+ <h1>🎌 Anime Recommendation System</h1>
460
+ <p>Add your favorite animes and get personalized recommendations!</p>
461
+ </div>
462
+ """)
463
+
464
+ with gr.Tab("Add Favorites"):
465
+ with gr.Row():
466
+ with gr.Column(scale=2):
467
+ search_input = gr.Textbox(
468
+ label="Search Anime",
469
+ placeholder="Enter anime name (e.g., 'Mushoku Tensei', 'Attack on Titan')",
470
+ lines=1
471
+ )
472
+
473
+ with gr.Row():
474
+ add_btn = gr.Button("Add to Favorites", variant="primary")
475
+ clear_btn = gr.Button("Clear All Favorites", variant="secondary")
476
+
477
+ with gr.Column(scale=2):
478
+ status_output = gr.Textbox(label="Status", lines=2)
479
+ favorites_display = gr.Textbox(
480
+ label="Your Favorites",
481
+ lines=10,
482
+ interactive=False,
483
+ value="No favorites added yet"
484
+ )
485
+
486
+ with gr.Tab("Get Recommendations"):
487
+ with gr.Row():
488
+ with gr.Column(scale=1):
489
+ gr.Markdown("### Recommendation Settings")
490
+
491
+ num_recs = gr.Slider(
492
+ minimum=5,
493
+ maximum=50,
494
+ value=20,
495
+ step=5,
496
+ label="Number of Recommendations"
497
+ )
498
+
499
+ gr.Markdown("### Filters")
500
+ show_movies = gr.Checkbox(label="Include Movies", value=True)
501
+ show_tv = gr.Checkbox(label="Include TV Series", value=True)
502
+ show_ova = gr.Checkbox(label="Include OVA/ONA/Special", value=True)
503
+ show_sequels = gr.Checkbox(label="Include Sequels", value=True)
504
+ show_hentai = gr.Checkbox(label="Include Hentai", value=False)
505
+
506
+ recommend_btn = gr.Button("Get Recommendations", variant="primary")
507
+
508
+ with gr.Column(scale=2):
509
+ recommendations_output = gr.Markdown(
510
+ label="Recommendations",
511
+ value="Add some favorite animes and click 'Get Recommendations'"
512
+ )
513
+
514
+ # Event handlers
515
+ add_btn.click(
516
+ fn=search_and_add_anime,
517
+ inputs=[search_input, favorites_state],
518
+ outputs=[status_output, favorites_state, search_input]
519
+ ).then(
520
+ fn=get_favorites_display,
521
+ inputs=[favorites_state],
522
+ outputs=[favorites_display]
523
+ )
524
+
525
+ clear_btn.click(
526
+ fn=clear_favorites,
527
+ inputs=[favorites_state],
528
+ outputs=[status_output, favorites_state, search_input]
529
+ ).then(
530
+ fn=get_favorites_display,
531
+ inputs=[favorites_state],
532
+ outputs=[favorites_display]
533
+ )
534
+
535
+ recommend_btn.click(
536
+ fn=get_recommendations_gradio,
537
+ inputs=[
538
+ favorites_state, num_recs, show_hentai, show_sequels,
539
+ show_movies, show_tv, show_ova
540
+ ],
541
+ outputs=[recommendations_output]
542
+ )
543
+
544
+ # Examples
545
+ with gr.Tab("Examples"):
546
+ gr.Markdown("""
547
+ ### How to use:
548
+ 1. **Add Favorites**: Search and add your favorite animes
549
+ 2. **Set Filters**: Choose what types of anime to include
550
+ 3. **Get Recommendations**: Click to get personalized suggestions
551
+
552
+ ### Example Searches:
553
+ - Mushoku Tensei
554
+ - Attack on Titan
555
+ - Demon Slayer
556
+ - Your Name
557
+ - Spirited Away
558
+ - One Piece
559
+ - Naruto
560
+ """)
561
+
562
+ return demo
563
+
564
+ if __name__ == "__main__":
565
+ demo = create_interface()
566
+ demo.launch(server_name="0.0.0.0", server_port=7860)