Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Commit 
							
							·
						
						b4ac9ca
	
1
								Parent(s):
							
							a5b5003
								
fix: Add user creation during validation
Browse files
    	
        src/distilabel_dataset_generator/apps/sft.py
    CHANGED
    
    | 
         @@ -247,18 +247,6 @@ def push_to_argilla( 
     | 
|
| 247 | 
         
             
                    progress(0.1, desc="Setting up user and workspace")
         
     | 
| 248 | 
         
             
                    client = get_argilla_client()
         
     | 
| 249 | 
         
             
                    hf_user = HfApi().whoami(token=oauth_token.token)["name"]
         
     | 
| 250 | 
         
            -
             
     | 
| 251 | 
         
            -
                    # Create user if it doesn't exist
         
     | 
| 252 | 
         
            -
                    rg_user = client.users(username=hf_user)
         
     | 
| 253 | 
         
            -
                    if rg_user is None:
         
     | 
| 254 | 
         
            -
                        rg_user = client.users.add(rg.User(username=hf_user, role="admin"))
         
     | 
| 255 | 
         
            -
             
     | 
| 256 | 
         
            -
                    # Create workspace if it doesn't exist
         
     | 
| 257 | 
         
            -
                    workspace = client.workspaces(name=rg_user.username)
         
     | 
| 258 | 
         
            -
                    if workspace is None:
         
     | 
| 259 | 
         
            -
                        workspace = client.workspaces.add(rg.Workspace(name=rg_user.username))
         
     | 
| 260 | 
         
            -
                        workspace.add_user(rg_user)
         
     | 
| 261 | 
         
            -
             
     | 
| 262 | 
         
             
                    if "messages" in dataframe.columns:
         
     | 
| 263 | 
         
             
                        settings = rg.Settings(
         
     | 
| 264 | 
         
             
                            fields=[
         
     | 
| 
         @@ -356,11 +344,11 @@ def push_to_argilla( 
     | 
|
| 356 | 
         
             
                        dataframe["prompt_embeddings"] = get_embeddings(dataframe["prompt"])
         
     | 
| 357 | 
         | 
| 358 | 
         
             
                    progress(0.5, desc="Creating dataset")
         
     | 
| 359 | 
         
            -
                    rg_dataset = client.datasets(name=dataset_name, workspace= 
     | 
| 360 | 
         
             
                    if rg_dataset is None:
         
     | 
| 361 | 
         
             
                        rg_dataset = rg.Dataset(
         
     | 
| 362 | 
         
             
                            name=dataset_name,
         
     | 
| 363 | 
         
            -
                            workspace= 
     | 
| 364 | 
         
             
                            settings=settings,
         
     | 
| 365 | 
         
             
                            client=client,
         
     | 
| 366 | 
         
             
                        )
         
     | 
| 
         @@ -386,6 +374,16 @@ def validate_argilla_dataset_name( 
     | 
|
| 386 | 
         
             
                client = get_argilla_client()
         
     | 
| 387 | 
         
             
                if dataset_name is None or dataset_name == "":
         
     | 
| 388 | 
         
             
                    raise gr.Error("Dataset name is required")
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 389 | 
         
             
                dataset = client.datasets(name=dataset_name, workspace=hf_user)
         
     | 
| 390 | 
         
             
                if dataset and not add_to_existing_dataset:
         
     | 
| 391 | 
         
             
                    raise gr.Error(f"Dataset {dataset_name} already exists")
         
     | 
| 
         | 
|
| 247 | 
         
             
                    progress(0.1, desc="Setting up user and workspace")
         
     | 
| 248 | 
         
             
                    client = get_argilla_client()
         
     | 
| 249 | 
         
             
                    hf_user = HfApi().whoami(token=oauth_token.token)["name"]
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 250 | 
         
             
                    if "messages" in dataframe.columns:
         
     | 
| 251 | 
         
             
                        settings = rg.Settings(
         
     | 
| 252 | 
         
             
                            fields=[
         
     | 
| 
         | 
|
| 344 | 
         
             
                        dataframe["prompt_embeddings"] = get_embeddings(dataframe["prompt"])
         
     | 
| 345 | 
         | 
| 346 | 
         
             
                    progress(0.5, desc="Creating dataset")
         
     | 
| 347 | 
         
            +
                    rg_dataset = client.datasets(name=dataset_name, workspace=hf_user)
         
     | 
| 348 | 
         
             
                    if rg_dataset is None:
         
     | 
| 349 | 
         
             
                        rg_dataset = rg.Dataset(
         
     | 
| 350 | 
         
             
                            name=dataset_name,
         
     | 
| 351 | 
         
            +
                            workspace=hf_user,
         
     | 
| 352 | 
         
             
                            settings=settings,
         
     | 
| 353 | 
         
             
                            client=client,
         
     | 
| 354 | 
         
             
                        )
         
     | 
| 
         | 
|
| 374 | 
         
             
                client = get_argilla_client()
         
     | 
| 375 | 
         
             
                if dataset_name is None or dataset_name == "":
         
     | 
| 376 | 
         
             
                    raise gr.Error("Dataset name is required")
         
     | 
| 377 | 
         
            +
                # Create user if it doesn't exist
         
     | 
| 378 | 
         
            +
                rg_user = client.users(username=hf_user)
         
     | 
| 379 | 
         
            +
                if rg_user is None:
         
     | 
| 380 | 
         
            +
                    rg_user = client.users.add(rg.User(username=hf_user, role="admin"))
         
     | 
| 381 | 
         
            +
                # Create workspace if it doesn't exist
         
     | 
| 382 | 
         
            +
                workspace = client.workspaces(name=hf_user)
         
     | 
| 383 | 
         
            +
                if workspace is None:
         
     | 
| 384 | 
         
            +
                    workspace = client.workspaces.add(rg.Workspace(name=hf_user))
         
     | 
| 385 | 
         
            +
                    workspace.add_user(hf_user)
         
     | 
| 386 | 
         
            +
                # Check if dataset exists
         
     | 
| 387 | 
         
             
                dataset = client.datasets(name=dataset_name, workspace=hf_user)
         
     | 
| 388 | 
         
             
                if dataset and not add_to_existing_dataset:
         
     | 
| 389 | 
         
             
                    raise gr.Error(f"Dataset {dataset_name} already exists")
         
     |