support for custom messages field in sharegpt (#1651)
Browse files
src/axolotl/prompt_strategies/sharegpt.py
CHANGED
|
@@ -86,6 +86,8 @@ def build_loader(
|
|
| 86 |
)
|
| 87 |
if ds_cfg and "strict" in ds_cfg and hasattr(strategy, "strict"):
|
| 88 |
strategy.strict = ds_cfg["strict"]
|
|
|
|
|
|
|
| 89 |
return strategy
|
| 90 |
|
| 91 |
return _load
|
|
@@ -97,6 +99,7 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
|
| 97 |
"""
|
| 98 |
|
| 99 |
_strict = False
|
|
|
|
| 100 |
|
| 101 |
@property
|
| 102 |
def strict(self):
|
|
@@ -106,8 +109,16 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
|
| 106 |
def strict(self, strict):
|
| 107 |
self._strict = strict
|
| 108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
def get_conversation_thread(self, prompt):
|
| 110 |
-
conversations = prompt[
|
| 111 |
if self.strict:
|
| 112 |
return conversations
|
| 113 |
role_key = "from"
|
|
|
|
| 86 |
)
|
| 87 |
if ds_cfg and "strict" in ds_cfg and hasattr(strategy, "strict"):
|
| 88 |
strategy.strict = ds_cfg["strict"]
|
| 89 |
+
if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"):
|
| 90 |
+
strategy.messages = ds_cfg["field_messages"]
|
| 91 |
return strategy
|
| 92 |
|
| 93 |
return _load
|
|
|
|
| 99 |
"""
|
| 100 |
|
| 101 |
_strict = False
|
| 102 |
+
_messages = "conversations"
|
| 103 |
|
| 104 |
@property
|
| 105 |
def strict(self):
|
|
|
|
| 109 |
def strict(self, strict):
|
| 110 |
self._strict = strict
|
| 111 |
|
| 112 |
+
@property
|
| 113 |
+
def messages(self):
|
| 114 |
+
return self._messages
|
| 115 |
+
|
| 116 |
+
@messages.setter
|
| 117 |
+
def messages(self, messages):
|
| 118 |
+
self._messages = messages
|
| 119 |
+
|
| 120 |
def get_conversation_thread(self, prompt):
|
| 121 |
+
conversations = prompt[self.messages]
|
| 122 |
if self.strict:
|
| 123 |
return conversations
|
| 124 |
role_key = "from"
|
src/axolotl/utils/config/models/input/v0_4_1/__init__.py
CHANGED
|
@@ -109,6 +109,7 @@ class SFTDataset(BaseModel):
|
|
| 109 |
field: Optional[str] = None
|
| 110 |
field_human: Optional[str] = None
|
| 111 |
field_model: Optional[str] = None
|
|
|
|
| 112 |
|
| 113 |
roles: Optional[Dict[str, List[str]]] = None
|
| 114 |
|
|
|
|
| 109 |
field: Optional[str] = None
|
| 110 |
field_human: Optional[str] = None
|
| 111 |
field_model: Optional[str] = None
|
| 112 |
+
field_messages: Optional[str] = None
|
| 113 |
|
| 114 |
roles: Optional[Dict[str, List[str]]] = None
|
| 115 |
|