	space update

app.py CHANGED
@@ -71,7 +71,23 @@ def load_model(custom_model_path=None):
 
     if os.path.exists(model_path):
         try:
-            u_model.load_state_dict(torch.load(model_path, map_location="cpu"))
+            state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
+
+            # Handle potential key mapping issues
+            if "embedding.weight" in state_dict and "embedding.embedding.weight" not in state_dict:
+                # Map old key names to new key names
+                new_state_dict = {}
+                for key, value in state_dict.items():
+                    if key == "embedding.weight":
+                        new_state_dict["embedding.embedding.weight"] = value
+                    elif key == "pos_embedding.weight":
+                        # Skip positional embedding if not expected
+                        continue
+                    else:
+                        new_state_dict[key] = value
+                state_dict = new_state_dict
+
+            u_model.load_state_dict(state_dict)
             u_model.eval()
             print("✅ Model weights loaded successfully!")
             return u_model, u_tokenizer, f"✅ Model loaded from: {model_path}"
@@ -188,15 +204,13 @@ with gr.Blocks(title="🤖 Usta Model Chat", theme=gr.themes.Soft()) as demo:
             gr.Markdown("### 📤 Model Upload (Optional)")
             model_file = gr.File(
                 label="Upload your own model.pth file",
-                file_types=[".pth", ".pt"],
-                info="Upload a custom UstaModel checkpoint to use instead of the default model"
+                file_types=[".pth", ".pt"]
             )
             upload_btn = gr.Button("Load Model", variant="primary")
             model_status_display = gr.Textbox(
                 label="Model Status",
                 value=model_status,
-                interactive=False,
-                info="Shows the current model loading status"
+                interactive=False
             )
 
         with gr.Column(scale=1):
@@ -205,8 +219,7 @@ with gr.Blocks(title="🤖 Usta Model Chat", theme=gr.themes.Soft()) as demo:
             gr.Markdown("### ⚙️ Generation Settings")
             system_msg = gr.Textbox(
                 value="You are Usta, a geographical knowledge assistant trained from scratch.",
-                label="System message",
-                info="Note: This model focuses on geographical knowledge"
+                label="System message"
             )
             max_tokens = gr.Slider(minimum=1, maximum=30, value=20, step=1, label="Max new tokens")
             temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature")
@@ -215,8 +228,7 @@ with gr.Blocks(title="🤖 Usta Model Chat", theme=gr.themes.Soft()) as demo:
                 maximum=1.0,
                 value=0.95,
                 step=0.05,
-                label="Top-p (nucleus sampling)",
-                info="Note: This parameter is not used by UstaModel"
+                label="Top-p (nucleus sampling)"
             )
 
     # Chat interface
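
The new loading path splits the old one-liner into an explicit load-then-remap step: checkpoints saved with the old flat key layout (embedding.weight) are renamed to the nested layout (embedding.embedding.weight), and a stale pos_embedding.weight entry is dropped instead of being passed to load_state_dict. A minimal standalone sketch of the same remapping follows; the tensor shapes and the extra lm_head.weight key are made-up placeholders, not taken from UstaModel:

import torch

# Fake "old layout" checkpoint; shapes are illustrative only.
old_ckpt = {
    "embedding.weight": torch.randn(32, 8),
    "pos_embedding.weight": torch.randn(16, 8),  # stale entry to drop
    "lm_head.weight": torch.randn(32, 8),
}

def remap_keys(state_dict):
    """Mirror the remapping added in load_model above."""
    if "embedding.weight" in state_dict and "embedding.embedding.weight" not in state_dict:
        remapped = {}
        for key, value in state_dict.items():
            if key == "embedding.weight":
                remapped["embedding.embedding.weight"] = value  # rename
            elif key == "pos_embedding.weight":
                continue  # not expected by the current model class
            else:
                remapped[key] = value
        return remapped
    return state_dict

print(sorted(remap_keys(old_ckpt)))
# ['embedding.embedding.weight', 'lm_head.weight']

An alternative would have been u_model.load_state_dict(state_dict, strict=False), which returns the missing and unexpected key lists instead of raising; the explicit rewrite keeps loading strict for every mismatch other than the two known ones.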
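
One caveat with the new call: torch.load(..., weights_only=False) runs the full pickle machinery, which can execute arbitrary code embedded in the file, and this Space accepts user-uploaded checkpoints. A hedged sketch of a lower-risk variant, assuming the checkpoints are plain tensor dicts (which the restricted loader, available since PyTorch 1.13, handles fine):

import torch

def load_checkpoint(path):
    # Try the restricted loader first; it refuses arbitrary pickled
    # objects and is sufficient for plain tensor state_dicts.
    try:
        return torch.load(path, map_location="cpu", weights_only=True)
    except Exception:
        # Full unpickling; only acceptable for files you trust,
        # since pickle can run arbitrary code.
        return torch.load(path, map_location="cpu", weights_only=False)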
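
The UI edits all delete info= keyword arguments. In Gradio, info is a form-component feature (it exists on gr.Textbox and gr.Slider in recent releases); presumably the installed version rejected it on one or more of these components with a TypeError at startup, though the commit message does not say. If the lost helper text matters, one version-independent substitute is a small Markdown caption next to each component, sketched below (the layout details are assumptions):

import gradio as gr

with gr.Blocks() as demo:
    model_file = gr.File(
        label="Upload your own model.pth file",
        file_types=[".pth", ".pt"],
    )
    # Restore the deleted hint without relying on the `info` kwarg.
    gr.Markdown("Upload a custom UstaModel checkpoint to use instead of the default model.")

    model_status_display = gr.Textbox(label="Model Status", interactive=False)
    gr.Markdown("Shows the current model loading status.")

demo.launch()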
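
Finally, the deleted hint on the top-p slider ("This parameter is not used by UstaModel") is worth remembering: the slider still renders but is a no-op for generation. For reference, honoring it would mean filtering logits before sampling, along these generic nucleus-sampling lines (a sketch, not code from this repo):

import torch

def top_p_sample(logits, top_p=0.95, temperature=1.0):
    """Generic nucleus sampling over a 1-D logits tensor."""
    logits = logits / temperature
    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
    probs = torch.softmax(sorted_logits, dim=-1)
    cumulative = torch.cumsum(probs, dim=-1)
    # Drop tokens once the mass *before* them already exceeds top_p,
    # so the first token crossing the threshold is still kept.
    remove = cumulative - probs > top_p
    sorted_logits[remove] = float("-inf")
    # Undo the sort so indices line up with the vocabulary again.
    filtered = torch.empty_like(logits)
    filtered[sorted_idx] = sorted_logits
    return torch.multinomial(torch.softmax(filtered, dim=-1), 1).item()

# Usage: next_id = top_p_sample(model_logits, top_p=0.95)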
