Spaces:
Paused
Paused
tmzh
committed on
Commit
·
ebf567b
1
Parent(s):
113ecf4
add natural language queries
Browse files
- agent.py +46 -9
- app.py +12 -0
- functions.py +5 -39
- templates/base.html +1 -0
- templates/query.html +23 -0
agent.py
CHANGED
@@ -1,13 +1,26 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
|
|
2 |
import torch
|
3 |
from tools import tools
|
4 |
from transformers import (
|
5 |
AutoModelForCausalLM,
|
6 |
AutoTokenizer,
|
7 |
-
BitsAndBytesConfig
|
8 |
-
TextIteratorStreamer,
|
9 |
)
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
|
12 |
|
13 |
# specify how to quantize the model
|
@@ -23,11 +36,35 @@ model = AutoModelForCausalLM.from_pretrained(
|
|
23 |
)
|
24 |
|
25 |
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
|
|
|
|
|
|
31 |
|
32 |
-
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import re
|
3 |
+
import json
|
4 |
+
|
5 |
+
import functools
|
6 |
+
import functions
|
7 |
import torch
|
8 |
from tools import tools
|
9 |
from transformers import (
|
10 |
AutoModelForCausalLM,
|
11 |
AutoTokenizer,
|
12 |
+
BitsAndBytesConfig
|
|
|
13 |
)
|
14 |
|
15 |
+
# Collect the names of the tool functions exposed by functions.py.
# Only callables that are *defined* in functions.py are kept: names that the
# module merely imported (helpers from other packages, classes, ...) also show
# up in dir(functions) and must not become callable by model output.
all_functions = [
    name
    for name in dir(functions)
    if not name.startswith("__")
    and callable(getattr(functions, name))
    and getattr(getattr(functions, name), "__module__", None) == functions.__name__
]

# Dispatch table: tool name -> callable implementing it.
# (The callables are stored directly; wrapping them in an argument-less
# functools.partial added nothing.)
names_to_functions = {name: getattr(functions, name) for name in all_functions}
|
22 |
+
|
23 |
+
|
24 |
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
|
25 |
|
26 |
# specify how to quantize the model
|
|
|
36 |
)
|
37 |
|
38 |
|
39 |
+
def extract_function_call(output):
    """Extract the JSON tool call embedded in decoded model output.

    The payload is the text between the ``<|python_tag|>`` and ``<|eom_id|>``
    markers.

    Args:
        output: Decoded model output with special tokens still present.

    Returns:
        The parsed tool call as a dict, or ``None`` when the markers are
        missing, the payload is not valid JSON, or the JSON value is not an
        object.
    """
    # DOTALL lets the payload span several lines; the non-greedy group stops
    # at the first end marker instead of swallowing everything up to the last.
    match = re.search(r'<\|python_tag\|>(.*?)<\|eom_id\|>', output, re.DOTALL)
    if match is None:
        return None

    try:
        function_call = json.loads(match.group(1))
    except json.JSONDecodeError:
        # Model output is untrusted text; a malformed payload means
        # "no usable tool call", not a crash.
        return None

    # Callers index the result by key, so anything but an object is unusable.
    return function_call if isinstance(function_call, dict) else None
|
46 |
+
|
47 |
+
|
48 |
+
def chatbot(query):
    """Answer a natural-language movie query with a single tool call.

    The query is sent to the model together with the tool definitions; the
    tool call found in the generated text is looked up in
    ``names_to_functions`` and executed.

    Args:
        query: The user's question as plain text.

    Returns:
        The ``'results'`` list of the tool's response, or an empty list when
        the model produced no usable tool call, named an unknown tool, or the
        response carries no ``'results'`` entry. A list is always returned so
        callers can iterate the value unconditionally.
    """
    messages = [
        {"role": "system", "content": "You are a movie search assistant bot who uses TMDB to help users find movies. Think step by step and identify the sequence of function calls that will help to answer."},
        {"role": "user", "content": query},
    ]

    tokenized_chat = tokenizer.apply_chat_template(
        messages, tools=tools, add_generation_prompt=True, tokenize=True, return_tensors="pt")
    # Inputs must live on the same device as the model weights; this is a
    # no-op when they already do.
    tokenized_chat = tokenized_chat.to(model.device)

    outputs = model.generate(tokenized_chat, max_new_tokens=128)
    # Decode only the newly generated tokens, skipping the echoed prompt.
    answer = tokenizer.batch_decode(outputs[:, tokenized_chat.shape[1]:])[0]
    tool_call = extract_function_call(answer)

    if not tool_call:
        print("No tool calls found in the answer.")
        return []

    function_name = tool_call.get('name')
    function_params = tool_call.get('parameters')
    if not isinstance(function_params, dict):
        # Missing or malformed parameters: call the tool with its defaults.
        function_params = {}
    print("\nfunction_name: ", function_name,
          "\nfunction_params: ", function_params)

    # The tool name comes from generated text, so it may not exist.
    function = names_to_functions.get(function_name)
    if function is None:
        print("Unknown function requested: ", function_name)
        return []

    function_result = function(**function_params)
    # Error responses may not carry a 'results' entry.
    results = function_result.get('results', []) if isinstance(function_result, dict) else []
    print(results)
    return results
|
app.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
from flask import Flask, render_template, request, jsonify
|
|
|
2 |
from functions import query_tmdb
|
3 |
|
4 |
app = Flask(__name__)
|
@@ -83,5 +84,16 @@ def cast_filter(cast_id):
|
|
83 |
cast_info=cast_info)
|
84 |
|
85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
if __name__ == '__main__':
|
87 |
app.run(host='0.0.0.0', port=5000)
|
|
|
1 |
from flask import Flask, render_template, request, jsonify
|
2 |
+
from agent import chatbot
|
3 |
from functions import query_tmdb
|
4 |
|
5 |
app = Flask(__name__)
|
|
|
84 |
cast_info=cast_info)
|
85 |
|
86 |
|
87 |
+
@app.route('/query', methods=['GET', 'POST'])
def query():
    """Render the natural-language query page.

    GET shows the empty form. POST forwards the submitted text to
    ``chatbot`` and renders the movies it returns.
    """
    if request.method != 'POST':
        return render_template('query.html')

    # Named user_query so the local does not shadow this view function.
    user_query = (request.form.get('query') or '').strip()
    if not user_query:
        # Nothing to ask the model; show the bare form again.
        return render_template('query.html')

    # chatbot() can come back with None when the model made no tool call;
    # normalise to a list so the id listing below cannot fail.
    result = chatbot(user_query) or []
    print("Got result: ", result)
    print("Movie ids are: ", [r['id'] for r in result])
    return render_template('query.html', result=result)
|
96 |
+
|
97 |
+
|
98 |
if __name__ == '__main__':
|
99 |
app.run(host='0.0.0.0', port=5000)
|
functions.py
CHANGED
@@ -14,47 +14,13 @@ def query_tmdb(endpoint, params={}):
|
|
14 |
return response.json()
|
15 |
|
16 |
|
17 |
-
def discover_movie(
|
18 |
-
certification_country=None, include_adult=False, include_video=False,
|
19 |
-
language="en-US", page=1, primary_release_year=None,
|
20 |
-
primary_release_date_gte=None, primary_release_date_lte=None,
|
21 |
-
region=None, release_date_gte=None, release_date_lte=None,
|
22 |
-
sort_by="popularity.desc", vote_average_gte=None, vote_average_lte=None):
|
23 |
endpoint = f"{BASE_URL}/discover/movie"
|
24 |
-
params = {
|
25 |
-
'include_adult': include_adult,
|
26 |
-
'include_video': include_video,
|
27 |
-
'language': language,
|
28 |
-
'page': page,
|
29 |
-
'sort_by': sort_by
|
30 |
-
}
|
31 |
-
|
32 |
-
# Add optional parameters if they're provided
|
33 |
-
if certification:
|
34 |
-
params['certification'] = certification
|
35 |
-
if certification_gte:
|
36 |
-
params['certification.gte'] = certification_gte
|
37 |
-
if certification_lte:
|
38 |
-
params['certification.lte'] = certification_lte
|
39 |
-
if certification_country:
|
40 |
-
params['certification_country'] = certification_country
|
41 |
-
if primary_release_year:
|
42 |
-
params['primary_release_year'] = primary_release_year
|
43 |
-
if primary_release_date_gte:
|
44 |
-
params['primary_release_date.gte'] = primary_release_date_gte
|
45 |
-
if primary_release_date_lte:
|
46 |
-
params['primary_release_date.lte'] = primary_release_date_lte
|
47 |
-
if region:
|
48 |
-
params['region'] = region
|
49 |
-
if release_date_gte:
|
50 |
-
params['release_date.gte'] = release_date_gte
|
51 |
-
if release_date_lte:
|
52 |
-
params['release_date.lte'] = release_date_lte
|
53 |
-
if vote_average_gte:
|
54 |
-
params['vote_average.gte'] = vote_average_gte
|
55 |
-
if vote_average_lte:
|
56 |
-
params['vote_average.lte'] = vote_average_lte
|
57 |
|
|
|
|
|
|
|
58 |
response = query_tmdb(endpoint, params=params)
|
59 |
return response
|
60 |
|
|
|
14 |
return response.json()
|
15 |
|
16 |
|
17 |
+
def discover_movie(include_adult=False, include_video=False, language="en-US", page=1, sort_by="popularity.desc", **kwargs):
    """Query the TMDB ``/discover/movie`` endpoint.

    Args:
        include_adult: Sent as the ``include_adult`` query parameter.
        include_video: Sent as the ``include_video`` query parameter.
        language: Sent as the ``language`` query parameter.
        page: Sent as the ``page`` query parameter.
        sort_by: Sent as the ``sort_by`` query parameter.
        **kwargs: Any further query parameters, forwarded under the key they
            were passed with; entries whose value is ``None`` are dropped.
            NOTE(review): TMDB filter names use dots (e.g.
            ``primary_release_date.gte``) -- confirm the caller passes keys
            in that form.

    Returns:
        Whatever ``query_tmdb`` returns for the request.
    """
    endpoint = f"{BASE_URL}/discover/movie"

    # The named arguments must be part of the request too; otherwise a
    # caller-supplied page or sort order is silently ignored.
    params = {
        'include_adult': include_adult,
        'include_video': include_video,
        'language': language,
        'page': page,
        'sort_by': sort_by,
    }
    # Optional filters: forward only those that were actually given.
    params.update({key: value for key, value in kwargs.items() if value is not None})

    response = query_tmdb(endpoint, params=params)
    return response
|
26 |
|
templates/base.html
CHANGED
@@ -16,6 +16,7 @@
|
|
16 |
<li><a href="/">Popular Movies</a></li>
|
17 |
<li><a href="/genres">By Genre</a></li>
|
18 |
<li><a href="/cast">By Cast</a></li>
|
|
|
19 |
</ul>
|
20 |
</div>
|
21 |
<div id="main-content">
|
|
|
16 |
<li><a href="/">Popular Movies</a></li>
|
17 |
<li><a href="/genres">By Genre</a></li>
|
18 |
<li><a href="/cast">By Cast</a></li>
|
19 |
+
<li><a href="/query">Query</a></li>
|
20 |
</ul>
|
21 |
</div>
|
22 |
<div id="main-content">
|
templates/query.html
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{% extends 'base.html' %}

{% block content %}
{# Natural-language query page: a search form plus, after a POST, one tile per returned movie. #}
<h1>Query</h1>
<form action="/query" method="post">
    <input type="text" name="query" placeholder="Enter your query...">
    <input type="submit" value="Submit">
</form>

{% if result %}
<h2>Query Result</h2>
<div>
    {% for movie in result %}
    <div class="tile-element" onclick="showDetails({{ movie.id }})">
        {# Skip the image when there is no poster; otherwise the URL would end in "w500None". #}
        {% if movie.poster_path %}
        <img src="https://image.tmdb.org/t/p/w500{{ movie.poster_path }}" alt="{{ movie.title }}">
        {% endif %}
        <h4>{{ movie.title }}</h4>
    </div>
    {% endfor %}
</div>
{% elif result is defined %}
{# A query was submitted but produced nothing to show. #}
<p>No results found.</p>
{% endif %}
{% endblock %}
|