Rishi-19 committed on
Commit
73ccf8c
·
verified ·
1 Parent(s): 7f0eed4

changed python example

Browse files
Files changed (1) hide show
  1. README.md +21 -4
README.md CHANGED
@@ -15,15 +15,32 @@ This model is optimized for financial analysis, valuation calculations, and fina
15
  ## Example Usage
16
  ```python
17
  from transformers import AutoModelForCausalLM, AutoTokenizer
18
- from peft import PeftModel
 
 
 
 
19
 
20
  # Load the model and tokenizer
21
  model_name = "Rishi-19/deepseek_finetuned_model_rishi"
22
  tokenizer = AutoTokenizer.from_pretrained(model_name)
23
- model = AutoModelForCausalLM.from_pretrained(model_name)
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  # Generate text
26
- inputs = tokenizer("Calculate the Net Present Value of a project with initial investment of $1M", return_tensors="pt")
27
- outputs = model.generate(**inputs, max_length=200)
 
28
  print(tokenizer.decode(outputs[0]))
29
  ```
 
15
  ## Example Usage
16
  ```python
17
  from transformers import AutoModelForCausalLM, AutoTokenizer
18
+ from peft import PeftModel, PeftConfig
19
+ import torch
20
+
21
+ # Set device
22
+ device = "cuda" if torch.cuda.is_available() else "cpu"
23
 
24
  # Load the model and tokenizer
25
  model_name = "Rishi-19/deepseek_finetuned_model_rishi"
26
  tokenizer = AutoTokenizer.from_pretrained(model_name)
27
+
28
+ # Load the base model first
29
+ peft_config = PeftConfig.from_pretrained(model_name)
30
+ base_model = AutoModelForCausalLM.from_pretrained(
31
+     peft_config.base_model_name_or_path,
32
+     torch_dtype=torch.float16, # Use half precision to save memory
33
+     device_map="auto",
34
+     trust_remote_code=True
35
+ )
36
+
37
+ # Then load the PEFT adapter
38
+ model = PeftModel.from_pretrained(base_model, model_name)
39
+ model.eval() # Set to evaluation mode
40
 
41
  # Generate text
42
+ inputs = tokenizer("Calculate the Net Present Value of a project with initial investment of $1M", return_tensors="pt").to(device)
43
+ with torch.no_grad():
44
+     outputs = model.generate(**inputs, max_length=200)
45
  print(tokenizer.decode(outputs[0]))
46
  ```