|
好的,这是基于我们之前的讨论,并且**假设你已经将Mermaid图成功转换为名为 `conceptual_knowledge_graph.png` 的图片,并打算将其上传到你的Hugging Face仓库根目录**的 `README.md` 完整内容。 |
|
|
|
**请注意:** 你需要将 `[你的训练样本数]`、`[你的训练epoch数]`、`[你的测试样本数]` 这些占位符替换为你的实际数值。 |
|
|
|
```markdown |
|
--- |
|
license: apache-2.0 |
|
language: |
|
- zh |
|
- en |
|
tags: |
|
- trajectory-prediction |
|
- llm |
|
- lora |
|
- gpt2 |
|
- physics-informed |
|
- autonomous-driving |
|
- motion-forecasting |
|
pipeline_tag: text-generation |
|
base_model: gpt2 |
|
--- |
|
|
|
# GPT-2 LoRA for Physics-Informed Trajectory Prediction |
|
|
|
This repository contains a LoRA (Low-Rank Adaptation) adapter fine-tuned on the `gpt2` base model. The goal is to predict physically plausible future trajectories for autonomous driving scenarios, implicitly incorporating simple physical laws through data and training. |
|
|
|
**开发团队/作者:** 天算AI科技研发实验室 (Natural Algorithm AI R&D Lab) |
|
**项目/研究主页 (可选):** [如果适用,请在此处添加链接] |
|
|
|
## 模型描述 |
|
|
|
该模型是一个经过微调的 `gpt2` 版本,通过LoRA技术高效地学习从历史轨迹数据预测未来轨迹。微调的核心思想是让语言模型不仅学习序列模式,还能在一定程度上遵循基本的物理运动规律,如匀速和匀加速运动。 |
|
|
|
## 微调过程简述 |
|
|
|
1. **基础模型:** `gpt2` (来自Hugging Face Transformers)。 |
|
2. **数据集:** |
|
* **类型:** 综合生成的文本格式轨迹数据。 |
|
* **格式:** 每个样本包含一段历史轨迹和对应的未来真实轨迹,表示为 `历史: x1,y1,vx1,vy1; ... 预测: xN,yN,vxN,vyN; ...`。 |
|
* **物理规律:** 数据生成脚本中包含了匀速直线运动和匀加速直线运动模型,确保训练数据在理想情况下符合基础物理。时间步长 `dt` 设置为 0.1秒。 |
|
* **规模:** 使用了约 [你的训练样本数,例如 300] 条样本进行微调演示。 |
|
3. **微调技术:** |
|
* **LoRA (Low-Rank Adaptation):** 主要对`gpt2`模型中的注意力权重 (`c_attn`) 应用LoRA层。 |
|
* **LoRA参数:**秩 (r) = 8, alpha = 16, dropout = 0.05。 |
|
* **训练设置:** 在Google Colab T4 GPU上进行了 [你的训练epoch数,例如 5] 个epoch的训练,批次大小为4,学习率为3e-4。 |
|
4. **目标:** 模型学习根据给定的历史轨迹(包括位置x, y和速度vx, vy)续写生成未来若干时间步的轨迹。 |
|
|
|
## 如何使用 |
|
|
|
下面的代码片段展示了如何加载基础 `gpt2` 模型并应用此LoRA适配器进行推理: |
|
|
|
```python |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
from peft import PeftModel |
|
import torch |
|
|
|
# 你的模型在Hugging Face Hub上的ID |
|
adapter_repo_id = "jinv2/gpt2-lora-trajectory-prediction" |
|
base_model_name = "gpt2" |
|
|
|
# 1. 加载基础模型 |
|
base_model = AutoModelForCausalLM.from_pretrained(base_model_name) |
|
|
|
# 2. 加载分词器 (通常与适配器一起保存,或者与基础模型一致) |
|
tokenizer = AutoTokenizer.from_pretrained(adapter_repo_id) |
|
if tokenizer.pad_token is None: |
|
tokenizer.pad_token = tokenizer.eos_token |
|
|
|
# 3. 加载LoRA适配器 |
|
model = PeftModel.from_pretrained(base_model, adapter_repo_id) |
|
|
|
model.eval() # 设置为评估模式 |
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
model.to(device) |
|
|
|
# 4. 准备输入并进行预测 |
|
# 假设历史轨迹有2个点,预测未来2个点 |
|
# dt = 0.1 (与训练时一致) |
|
history_points_str = "1.00,1.00,0.50,0.00; 1.05,1.00,0.50,0.00" # 示例历史 |
|
prompt = f"历史: {history_points_str}; 预测:" |
|
|
|
inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(device) |
|
|
|
with torch.no_grad(): |
|
outputs = model.generate( |
|
**inputs, |
|
max_new_tokens=50, # 足够生成 NUM_FUTURE_POINTS 个点 |
|
num_return_sequences=1, |
|
pad_token_id=tokenizer.eos_token_id, |
|
eos_token_id=tokenizer.eos_token_id, # 确保模型知道何时停止 |
|
do_sample=False # 使用贪婪解码进行确定性输出 |
|
) |
|
|
|
generated_text_full = tokenizer.decode(outputs[0], skip_special_tokens=True) |
|
predicted_part = "" |
|
if "预测:" in generated_text_full: |
|
predicted_part = generated_text_full.split("预测:")[1].strip() |
|
# 清理可能的末尾分号或eos token的文本残留 |
|
if predicted_part.endswith(tokenizer.eos_token): |
|
predicted_part = predicted_part[:-len(tokenizer.eos_token)].strip() |
|
if predicted_part.endswith(';'): |
|
predicted_part = predicted_part[:-1].strip() |
|
else: |
|
if prompt in generated_text_full: |
|
predicted_part = generated_text_full[len(prompt):].strip() |
|
else: |
|
predicted_part = generated_text_full.strip() |
|
|
|
if predicted_part.endswith(tokenizer.eos_token): |
|
predicted_part = predicted_part[:-len(tokenizer.eos_token)].strip() |
|
if predicted_part.endswith(';'): |
|
predicted_part = predicted_part[:-1].strip() |
|
|
|
|
|
print(f"提示: {prompt}") |
|
print(f"模型预测的未来轨迹点 (文本): {predicted_part}") |
|
|
|
# 示例:解析预测文本的函数 |
|
def parse_trajectory_string(traj_str): |
|
points = [] |
|
if not traj_str or not traj_str.strip(): return points |
|
point_strs = traj_str.strip().split(';') |
|
for p_str in point_strs: |
|
if p_str.strip(): |
|
try: |
|
coords = [float(c.strip()) for c in p_str.split(',')] |
|
if len(coords) == 4: points.append({'x': coords[0], 'y': coords[1], 'vx': coords[2], 'vy': coords[3]}) |
|
except ValueError: print(f"Warning: Could not parse point string: '{p_str}'") |
|
return points |
|
|
|
predicted_points = parse_trajectory_string(predicted_part) |
|
print(f"解析后的预测点: {predicted_points}") |
|
``` |
|
|
|
## 评估结果 |
|
|
|
在包含 [你的测试样本数,例如 50] 条样本的独立测试集上进行了评估。评估指标主要关注预测轨迹的准确性和物理一致性。 |
|
|
|
* **平均位移误差 (Average Displacement Error, ADE):** `0.2684` 米 (平均每个预测时间步与真值的欧氏距离) |
|
* **最终位移误差 (Final Displacement Error, FDE):** `0.2810` 米 (预测轨迹最后一个点与真值的欧氏距离) |
|
|
|
* **运动学一致性 (平均绝对误差):** |
|
* **Vx (X轴速度) 误差:** `0.2844` m/s (模型预测的Vx与根据位移变化推断的Vx之间的差异) |
|
* **Vy (Y轴速度) 误差:** `0.1708` m/s (模型预测的Vy与根据位移变化推断的Vy之间的差异) |
|
* *解读:这些值越小,表明模型在预测速度和预测位置之间的一致性越好。当前的误差值表明还有提升空间,特别是在某些样本中,预测速度与实际位移变化不太吻合。* |
|
|
|
* **动力学约束 (速度限制):** |
|
* **速度限制违反率 (V_max = 2.5 m/s):** `0.00%` |
|
* *解读:模型生成的轨迹点速度值均未超过设定的2.5 m/s的上限,这表明模型学习到了训练数据中的速度范围。* |
|
|
|
* **模型输出观察点:** |
|
* 模型通常能按要求数量生成预测点,但有时会生成略多于请求数量的点,此时需要截取。 |
|
* 在少数情况下,模型输出末尾可能包含非标准格式的文本片段,导致解析警告。这可能与 `max_new_tokens` 设置或序列结束标记的学习有关。 |
|
* 部分样本的ADE/FDE或运动学误差显著高于平均水平,表明模型在特定场景或运动模式下的预测能力有待加强。 |
|
|
|
**总体而言,对于一个基于小型基础模型 (`gpt2`) 并经过短时间、小数据量LoRA微调的实验性模型,其展现了学习轨迹模式和部分物理规律的潜力。** |
|
|
|
## 可视化知识图谱 (概念性) |
|
|
|
为了更好地理解LLM与物理常识模块在无人驾驶轨迹预测中的结合,我们可以构想一个概念性的知识图谱。 |
|
|
|
 |
|
*(上图展示了系统的主要组件和数据流。请确保名为 `conceptual_knowledge_graph.png` 的图片已上传到本仓库的根目录。)* |
|
|
|
## 局限性与未来工作 |
|
|
|
* 当前模型基于简化的2D物理场景,未考虑复杂的车辆动力学、与环境的交互(如避障)。 |
|
* 物理规律的遵循主要依赖于训练数据的质量和覆盖度。 |
|
* 模型输出的稳定性和对特定场景的泛化能力有待提升。 |
|
* **未来工作:** |
|
* 引入更复杂的、可微的物理模块直接集成到模型或损失函数中。 |
|
* 使用更丰富的真实世界数据集进行训练。 |
|
* 探索多模态输入(如图像、LiDAR)与轨迹预测的结合。 |
|
* 对模型的可解释性进行研究。 |
|
|
|
## 版权与许可 |
|
|
|
© 2024 天算AI科技研发实验室 (Natural Algorithm AI R&D Lab). |
|
|
|
本项目根据 Apache License 2.0 许可证授权。详情请参阅仓库中的 `LICENSE` 文件。 |
|
*(请确保在仓库根目录添加一个名为 `LICENSE` 的文件,其中包含 Apache 2.0 许可证的完整文本。)* |
|
|
|
--- |