tmzh committed on
Commit
ebf567b
·
1 Parent(s): 113ecf4

add natural language queries

Browse files
Files changed (5) hide show
  1. agent.py +46 -9
  2. app.py +12 -0
  3. functions.py +5 -39
  4. templates/base.html +1 -0
  5. templates/query.html +23 -0
agent.py CHANGED
@@ -1,13 +1,26 @@
1
- import transformers
 
 
 
 
 
2
  import torch
3
  from tools import tools
4
  from transformers import (
5
  AutoModelForCausalLM,
6
  AutoTokenizer,
7
- BitsAndBytesConfig,
8
- TextIteratorStreamer,
9
  )
10
 
 
 
 
 
 
 
 
 
 
11
  model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
12
 
13
  # specify how to quantize the model
@@ -23,11 +36,35 @@ model = AutoModelForCausalLM.from_pretrained(
23
  )
24
 
25
 
26
- messages = [
27
- {"role": "system", "content": "You are a movie search assistant bot who uses TMDB to help users find movies. You should respond with movie IDs and natural language text summaries when asked for movie recommendations. You should only provide the movie ID and the summary, nothing else."},
28
- {"role": "user", "content": "Can you recommend a good action movie?"},
29
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
 
 
 
31
 
32
- inputs = tokenizer.apply_chat_template(
33
- messages, tools=tools, add_generation_prompt=True)
 
 
 
 
 
 
 
 
 
1
+
2
+ import re
3
+ import json
4
+
5
+ import functools
6
+ import functions
7
  import torch
8
  from tools import tools
9
  from transformers import (
10
  AutoModelForCausalLM,
11
  AutoTokenizer,
12
+ BitsAndBytesConfig
 
13
  )
14
 
15
# Collect the names of the callables that functions.py itself defines.
# Filtering on __module__ keeps callables that functions.py merely imported
# out of the list: the model's output picks which entry gets invoked, so
# only the module's own tools should be reachable.
all_functions = [
    name
    for name in dir(functions)
    if not name.startswith("__")
    and callable(getattr(functions, name))
    and getattr(getattr(functions, name), "__module__", None) == functions.__name__
]

# Dispatch table from tool name to callable.  The callables are stored
# directly; wrapping them in an argument-less functools.partial added nothing.
names_to_functions = {name: getattr(functions, name) for name in all_functions}
22
+
23
+
24
  model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
25
 
26
  # specify how to quantize the model
 
36
  )
37
 
38
 
39
def extract_function_call(output):
    """Extract the JSON tool call embedded in a decoded model output.

    The payload is the text between the ``<|python_tag|>`` and
    ``<|eom_id|>`` markers in ``output``.

    Args:
        output: Decoded model output with the special tokens still present.

    Returns:
        The value parsed from the JSON payload, or ``None`` when the markers
        are missing or the payload is not valid JSON.
    """
    # Non-greedy so only the first call is captured when several are present;
    # DOTALL so a payload spread over multiple lines still matches.
    match = re.search(r'<\|python_tag\|>(.*?)<\|eom_id\|>', output, re.DOTALL)
    if match is None:
        return None
    try:
        return json.loads(match.group(1))
    except json.JSONDecodeError:
        # A malformed payload is treated the same as "no tool call".
        return None
46
+
47
+
48
def chatbot(query):
    """Answer a natural-language movie query by running the model's tool call.

    Builds a chat prompt from ``query``, lets the model generate a reply,
    extracts the tool call from it and invokes the matching entry of
    ``names_to_functions`` with the supplied parameters.

    Args:
        query: The user's question as plain text.

    Returns:
        The ``results`` entry of the tool's response, or an empty list when
        the model produced no usable tool call or the response has no
        ``results`` key.
    """
    messages = [
        {"role": "system", "content": "You are a movie search assistant bot who uses TMDB to help users find movies. Think step by step and identify the sequence of function calls that will help to answer."},
        {"role": "user", "content": query},
    ]

    tokenized_chat = tokenizer.apply_chat_template(
        messages, tools=tools, add_generation_prompt=True, tokenize=True, return_tensors="pt")
    # generate() expects its inputs on the same device as the model weights.
    tokenized_chat = tokenized_chat.to(model.device)

    outputs = model.generate(tokenized_chat, max_new_tokens=128)
    # Decode only the newly generated tokens.  Special tokens are kept on
    # purpose: extract_function_call() locates the call by those markers.
    answer = tokenizer.batch_decode(outputs[:, tokenized_chat.shape[1]:])[0]
    tool_call = extract_function_call(answer)

    # Return an empty list rather than None so callers can always iterate.
    if not isinstance(tool_call, dict) or 'name' not in tool_call:
        print("No tool calls found in the answer.")
        return []

    function_name = tool_call['name']
    function_params = tool_call.get('parameters') or {}
    print("\nfunction_name: ", function_name,
          "\nfunction_params: ", function_params)

    tool = names_to_functions.get(function_name)
    if tool is None:
        # The model named a function that is not in the dispatch table.
        print("Unknown function requested: ", function_name)
        return []

    function_result = tool(**function_params)
    # A response without a 'results' key is reported as "no results".
    results = function_result.get('results', [])
    print(results)
    return results
app.py CHANGED
@@ -1,4 +1,5 @@
1
  from flask import Flask, render_template, request, jsonify
 
2
  from functions import query_tmdb
3
 
4
  app = Flask(__name__)
@@ -83,5 +84,16 @@ def cast_filter(cast_id):
83
  cast_info=cast_info)
84
 
85
 
 
 
 
 
 
 
 
 
 
 
 
86
  if __name__ == '__main__':
87
  app.run(host='0.0.0.0', port=5000)
 
1
  from flask import Flask, render_template, request, jsonify
2
+ from agent import chatbot
3
  from functions import query_tmdb
4
 
5
  app = Flask(__name__)
 
84
  cast_info=cast_info)
85
 
86
 
87
@app.route('/query', methods=['GET', 'POST'])
def query():
    """Render the natural-language query page.

    GET shows the empty form.  POST reads the ``query`` form field, passes
    it to ``chatbot`` and renders ``query.html`` with the resulting list.
    """
    if request.method != 'POST':
        return render_template('query.html')

    # Named user_query so the local does not shadow this view function.
    user_query = (request.form.get('query') or '').strip()
    if not user_query:
        # Nothing to ask the model; render the page with an empty result.
        return render_template('query.html', result=[])

    # chatbot() may come back with a falsy value (e.g. None) when it finds no
    # tool call; normalise that to an empty list so the code below can iterate.
    result = chatbot(user_query) or []
    print("Got result: ", result)
    print("Movie ids are: ", [r['id'] for r in result])
    return render_template('query.html', result=result)
96
+
97
+
98
  if __name__ == '__main__':
99
  app.run(host='0.0.0.0', port=5000)
functions.py CHANGED
@@ -14,47 +14,13 @@ def query_tmdb(endpoint, params={}):
14
  return response.json()
15
 
16
 
17
- def discover_movie(certification=None, certification_gte=None, certification_lte=None,
18
- certification_country=None, include_adult=False, include_video=False,
19
- language="en-US", page=1, primary_release_year=None,
20
- primary_release_date_gte=None, primary_release_date_lte=None,
21
- region=None, release_date_gte=None, release_date_lte=None,
22
- sort_by="popularity.desc", vote_average_gte=None, vote_average_lte=None):
23
  endpoint = f"{BASE_URL}/discover/movie"
24
- params = {
25
- 'include_adult': include_adult,
26
- 'include_video': include_video,
27
- 'language': language,
28
- 'page': page,
29
- 'sort_by': sort_by
30
- }
31
-
32
- # Add optional parameters if they're provided
33
- if certification:
34
- params['certification'] = certification
35
- if certification_gte:
36
- params['certification.gte'] = certification_gte
37
- if certification_lte:
38
- params['certification.lte'] = certification_lte
39
- if certification_country:
40
- params['certification_country'] = certification_country
41
- if primary_release_year:
42
- params['primary_release_year'] = primary_release_year
43
- if primary_release_date_gte:
44
- params['primary_release_date.gte'] = primary_release_date_gte
45
- if primary_release_date_lte:
46
- params['primary_release_date.lte'] = primary_release_date_lte
47
- if region:
48
- params['region'] = region
49
- if release_date_gte:
50
- params['release_date.gte'] = release_date_gte
51
- if release_date_lte:
52
- params['release_date.lte'] = release_date_lte
53
- if vote_average_gte:
54
- params['vote_average.gte'] = vote_average_gte
55
- if vote_average_lte:
56
- params['vote_average.lte'] = vote_average_lte
57
 
 
 
 
58
  response = query_tmdb(endpoint, params=params)
59
  return response
60
 
 
14
  return response.json()
15
 
16
 
17
def discover_movie(include_adult=False, include_video=False, language="en-US", page=1, sort_by="popularity.desc", **kwargs):
    """Call the ``/discover/movie`` endpoint through ``query_tmdb``.

    Args:
        include_adult: Sent as the ``include_adult`` request parameter.
        include_video: Sent as the ``include_video`` request parameter.
        language: Sent as the ``language`` request parameter.
        page: Sent as the ``page`` request parameter.
        sort_by: Sent as the ``sort_by`` request parameter.
        **kwargs: Additional request parameters, passed through under the
            same key; entries whose value is ``None`` are omitted.

    Returns:
        Whatever ``query_tmdb`` returns for the request.
    """
    endpoint = f"{BASE_URL}/discover/movie"
    # Seed the request with the named arguments; starting from an empty dict
    # silently dropped them.
    params = {
        'include_adult': include_adult,
        'include_video': include_video,
        'language': language,
        'page': page,
        'sort_by': sort_by,
    }
    # Add the optional filters, skipping any that were left unset.
    params.update(
        {key: value for key, value in kwargs.items() if value is not None})
    response = query_tmdb(endpoint, params=params)
    return response
26
 
templates/base.html CHANGED
@@ -16,6 +16,7 @@
16
  <li><a href="/">Popular Movies</a></li>
17
  <li><a href="/genres">By Genre</a></li>
18
  <li><a href="/cast">By Cast</a></li>
 
19
  </ul>
20
  </div>
21
  <div id="main-content">
 
16
  <li><a href="/">Popular Movies</a></li>
17
  <li><a href="/genres">By Genre</a></li>
18
  <li><a href="/cast">By Cast</a></li>
19
+ <li><a href="/query">Query</a></li>
20
  </ul>
21
  </div>
22
  <div id="main-content">
templates/query.html ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {% extends 'base.html' %}
2
+
3
+ {% block content %}
4
+ <h1>Query</h1>
5
+ <form action="/query" method="post">
6
+ <input type="text" name="query" placeholder="Enter your query...">
7
+ <input type="submit" value="Submit">
8
+ </form>
9
+
10
+ {% if result %}
11
+ <h2>Query Result</h2>
12
+ <div>
13
+ {% for movie in result %}
14
+ <div class="tile-element" onclick="showDetails({{ movie.id }})">
15
+ <img src="https://image.tmdb.org/t/p/w500{{ movie.poster_path }}" alt="{{ movie.title }}">
16
+ <h4>{{ movie.title }}</h4>
17
+ </div>
18
+ {% endfor %}
19
+ </div>
20
+ {% elif result is defined %}
21
+ <p>No results found.</p>
22
+ {% endif %}
23
+ {% endblock %}