Spaces:
Sleeping
Sleeping
import random | |
from copy import deepcopy | |
from datetime import datetime, time, timedelta | |
from typing import Dict, List, Optional, Union | |
from .long_context import ( | |
AUTOMOBILE_EXTENSION, | |
MA_5_EXTENSION, | |
MA_20_EXTENSION, | |
ORDER_DETAIL_EXTENSION, | |
TECHNOLOGY_EXTENSION, | |
TRANSACTION_HISTORY_EXTENSION, | |
WATCH_LIST_EXTENSION, | |
) | |
# Fixed "now" used by the simulation so that time-dependent results are reproducible.
CURRENT_TIME = datetime(2024, 9, 1, 10, 30)

# Default scenario used by TradingBot._load_scenario for any key the supplied
# scenario does not provide. It is deep-copied before use, so it is never mutated.
DEFAULT_STATE = {
    # Existing orders keyed by integer order ID.
    "orders": {
        12345: {
            "id": 12345,
            "order_type": "Buy",
            "symbol": "AAPL",
            "price": 210.65,
            "amount": 10,
            "status": "Completed",
        },
        12446: {
            "id": 12446,
            "order_type": "Sell",
            "symbol": "GOOG",
            "price": 2840.56,
            "amount": 5,
            "status": "Pending",
        },
    },
    "account_info": {
        "account_id": 12345,
        "balance": 10000.0,
        "binding_card": 1974202140965533,
    },
    "authenticated": False,
    "market_status": "Closed",
    # Next order ID to hand out. It must not collide with an ID already present in
    # "orders" above (the highest is 12446); a collision would make the first
    # place_order() overwrite an existing order.
    "order_counter": 12447,
    # Per-symbol market data. percent_change is expressed as a percentage
    # (update_stock_price recomputes it with a *100 factor).
    "stocks": {
        "AAPL": {
            "price": 227.16,
            "percent_change": 0.17,
            "volume": 2.552,
            "MA(5)": 227.11,
            "MA(20)": 227.09,
        },
        "GOOG": {
            "price": 2840.34,
            "percent_change": 0.24,
            "volume": 1.123,
            "MA(5)": 2835.67,
            "MA(20)": 2842.15,
        },
        "TSLA": {
            "price": 667.92,
            "percent_change": -0.12,
            "volume": 1.654,
            "MA(5)": 671.15,
            "MA(20)": 668.20,
        },
        "MSFT": {
            "price": 310.23,
            "percent_change": 0.09,
            "volume": 3.234,
            "MA(5)": 309.88,
            "MA(20)": 310.11,
        },
        "NVDA": {
            "price": 220.34,
            "percent_change": 0.34,
            "volume": 1.234,
            "MA(5)": 220.45,
            "MA(20)": 220.67,
        },
        "ALPH": {
            "price": 1320.45,
            "percent_change": -0.08,
            "volume": 1.567,
            "MA(5)": 1321.12,
            "MA(20)": 1325.78,
        },
        "OMEG": {
            "price": 457.23,
            "percent_change": 0.12,
            "volume": 2.345,
            "MA(5)": 456.78,
            "MA(20)": 458.12,
        },
        "QUAS": {
            "price": 725.89,
            "percent_change": -0.03,
            "volume": 1.789,
            "MA(5)": 726.45,
            "MA(20)": 728.00,
        },
        "NEPT": {
            "price": 88.34,
            "percent_change": 0.19,
            "volume": 0.654,
            "MA(5)": 88.21,
            "MA(20)": 88.67,
        },
        "SYNX": {
            "price": 345.67,
            "percent_change": 0.11,
            "volume": 2.112,
            "MA(5)": 345.34,
            "MA(20)": 346.12,
        },
        "ZETA": {
            "price": 22.09,
            "percent_change": -0.05,
            "volume": 0.789,
            "MA(5)": 22.12,
            "MA(20)": 22.34,
        },
    },
    "watch_list": ["NVDA"],
    "transaction_history": [],
    # Seed for the instance-local random.Random used to generate transaction timestamps.
    "random_seed": 1053520,
}
class TradingBot:
    """
    A class representing a trading bot for executing stock trades and managing a trading account.

    Attributes:
        orders (Dict[int, Dict[str, Union[str, float, int]]]): A dictionary of orders for purchasing and selling of stock, keyed by order ID.
        account_info (Dict[str, Union[int, float]]): Information about the trading account.
        authenticated (bool): Whether the user is currently authenticated.
        market_status (str): The current status of the market ('Open' or 'Closed').
        order_counter (int): A counter for generating unique order IDs.
        stocks (Dict[str, Dict[str, Union[float, int]]]): Information about various stocks.
        watch_list (List[str]): A list of stock symbols being watched.
        transaction_history (List[Dict[str, Union[str, float, int]]]): A history of trading account related transactions.
        long_context (bool): When True, several getters pad their output with the
            long-context extension data.
    """

    def __init__(self):
        """
        Initialize the TradingBot instance.

        The state attributes are only declared here; they are populated by
        `_load_scenario`, which must be called before any other method.
        """
        self.orders: Dict[int, Dict[str, Union[str, float, int]]]
        self.account_info: Dict[str, Union[int, float]]
        self.authenticated: bool
        self.market_status: str
        self.order_counter: int
        self.stocks: Dict[str, Dict[str, Union[float, int]]]
        self.watch_list: List[str]
        self.transaction_history: List[Dict[str, Union[str, float, int]]]
        self._api_description = "This tool belongs to the trading system, which allows users to trade stocks, manage their account, and view stock information."

    def _load_scenario(self, scenario: dict, long_context=False) -> None:
        """
        Load a scenario into the TradingBot.

        Any key missing from `scenario` falls back to a deep copy of DEFAULT_STATE.
        Values that ARE present are used as-is (not copied).

        Args:
            scenario (dict): A scenario dictionary containing data to load.
            long_context (bool): Whether getters should append the long-context extensions.
        """
        DEFAULT_STATE_COPY = deepcopy(DEFAULT_STATE)
        self.orders = scenario.get("orders", DEFAULT_STATE_COPY["orders"])
        # Scenarios loaded from JSON carry string keys; convert digit strings back to
        # int keys so lookups by integer order_id work.
        self.orders = {
            int(k) if isinstance(k, str) and k.isdigit() else k: v
            for k, v in self.orders.items()
        }
        self.account_info = scenario.get("account_info", DEFAULT_STATE_COPY["account_info"])
        self.authenticated = scenario.get(
            "authenticated", DEFAULT_STATE_COPY["authenticated"]
        )
        self.market_status = scenario.get(
            "market_status", DEFAULT_STATE_COPY["market_status"]
        )
        self.order_counter = scenario.get(
            "order_counter", DEFAULT_STATE_COPY["order_counter"]
        )  # Start counter from the next order ID
        self.stocks = scenario.get("stocks", DEFAULT_STATE_COPY["stocks"])
        self.watch_list = scenario.get("watch_list", DEFAULT_STATE_COPY["watch_list"])
        self.transaction_history = scenario.get(
            "transaction_history", DEFAULT_STATE_COPY["transaction_history"]
        )
        self.long_context = long_context
        # Instance-local RNG so generated timestamps are reproducible per scenario.
        self._random = random.Random(
            (scenario.get("random_seed", DEFAULT_STATE_COPY["random_seed"]))
        )

    def _generate_transaction_timestamp(self) -> str:
        """
        Generate a pseudo-random timestamp within one day after CURRENT_TIME.

        Returns:
            timestamp (str): A timestamp formatted as 'YYYY-MM-DD HH:MM:SS'.
        """
        # Define the start and end dates for the range
        start_date = CURRENT_TIME
        end_date = CURRENT_TIME + timedelta(days=1)
        start_timestamp = int(start_date.timestamp())
        end_timestamp = int(end_date.timestamp())
        # Generate a random timestamp within the range (inclusive on both ends)
        random_timestamp = self._random.randint(start_timestamp, end_timestamp)
        # Convert the random timestamp back to a (naive, local-time) datetime object
        random_date = datetime.fromtimestamp(random_timestamp)
        return random_date.strftime("%Y-%m-%d %H:%M:%S")

    def get_current_time(self) -> Dict[str, str]:
        """
        Get the current time.

        Returns:
            current_time (str): Current time in HH:MM AM/PM format.
        """
        return {"current_time": CURRENT_TIME.strftime("%I:%M %p")}

    def update_market_status(self, current_time_str: str) -> Dict[str, str]:
        """
        Update the market status based on the current time.

        Args:
            current_time_str (str): Current time in HH:MM AM/PM format.

        Returns:
            status (str): Status of the market. [Enum]: ["Open", "Closed"]
        """
        market_open_time = time(9, 30)  # Market opens at 9:30 AM
        market_close_time = time(16, 0)  # Market closes at 4:00 PM
        current_time = datetime.strptime(current_time_str, "%I:%M %p").time()
        if market_open_time <= current_time <= market_close_time:
            self.market_status = "Open"
            return {"status": "Open"}
        else:
            self.market_status = "Closed"
            return {"status": "Closed"}

    def get_symbol_by_name(self, name: str) -> Dict[str, str]:
        """
        Get the symbol of a stock by company name (exact, case-sensitive match).

        Args:
            name (str): Name of the company.

        Returns:
            symbol (str): Symbol of the stock or "Stock not found" if not available.
        """
        symbol_map = {
            "Apple": "AAPL",
            "Google": "GOOG",
            "Tesla": "TSLA",
            "Microsoft": "MSFT",
            "Nvidia": "NVDA",
            "Zeta Corp": "ZETA",
            "Alpha Tech": "ALPH",
            "Omega Industries": "OMEG",
            "Quasar Ltd.": "QUAS",
            "Neptune Systems": "NEPT",
            "Synex Solutions": "SYNX",
            "Amazon": "AMZN",
        }
        return {"symbol": symbol_map.get(name, "Stock not found")}

    def get_stock_info(self, symbol: str) -> Dict[str, Union[float, int, str]]:
        """
        Get the details of a stock.

        Args:
            symbol (str): Symbol that uniquely identifies the stock.

        Returns:
            price (float): Current price of the stock.
            percent_change (float): Percentage change in stock price.
            volume (float): Trading volume of the stock.
            MA(5) (float): 5-day Moving Average of the stock.
            MA(20) (float): 20-day Moving Average of the stock.
        """
        if symbol not in self.stocks:
            return {"error": f"Stock with symbol '{symbol}' not found."}
        if self.long_context:
            # Copy so the extension values never leak into the stored stock data.
            stock = self.stocks[symbol].copy()
            stock["MA(5)"] = MA_5_EXTENSION
            stock["MA(20)"] = MA_20_EXTENSION
            return stock
        return self.stocks[symbol]

    def get_order_details(self, order_id: int) -> Dict[str, Union[str, float, int]]:
        """
        Get the details of an order.

        Args:
            order_id (int): ID of the order.

        Returns:
            id (int): ID of the order.
            order_type (str): Type of the order.
            symbol (str): Symbol of the stock in the order.
            price (float): Price at which the order was placed.
            amount (int): Number of shares in the order.
            status (str): Current status of the order. [Enum]: ["Open", "Pending", "Completed", "Cancelled"]
        """
        if order_id not in self.orders:
            return {
                "error": f"Order with ID {order_id} not found. "
                + "Here is the list of orders_id: "
                + str(list(self.orders.keys()))
            }
        if self.long_context:
            order = self.orders[order_id].copy()
            symbol = order["symbol"]
            formatted_extension = {}
            for key, value in ORDER_DETAIL_EXTENSION.items():
                try:
                    formatted_extension[key] = value.format(symbol=symbol)
                except KeyError as e:
                    # The extension template referenced a placeholder other than {symbol}.
                    return {"error": f"KeyError during formatting: {str(e)}"}
            # Add formatted extension to the order metadata
            order["metadata"] = formatted_extension
            return order
        return self.orders[order_id]

    def cancel_order(self, order_id: int) -> Dict[str, Union[int, str]]:
        """
        Cancel an order.

        Args:
            order_id (int): ID of the order to cancel.

        Returns:
            order_id (int): ID of the cancelled order.
            status (str): New status of the order after cancellation attempt.
        """
        if order_id not in self.orders:
            return {"error": f"Order with ID {order_id} not found."}
        if self.orders[order_id]["status"] == "Completed":
            return {"error": f"Can't cancel order {order_id}. Order is already completed."}
        self.orders[order_id]["status"] = "Cancelled"
        return {"order_id": order_id, "status": "Cancelled"}

    def place_order(
        self, order_type: str, symbol: str, price: float, amount: int
    ) -> Dict[str, Union[int, str, float]]:
        """
        Place an order.

        Args:
            order_type (str): Type of the order (Buy/Sell).
            symbol (str): Symbol of the stock to trade.
            price (float): Price at which to place the order.
            amount (int): Number of shares to trade.

        Returns:
            order_id (int): ID of the newly placed order.
            order_type (str): Type of the order (Buy/Sell).
            status (str): Initial status of the order.
            price (float): Price at which the order was placed.
            amount (int): Number of shares in the order.
        """
        if not self.authenticated:
            return {"error": "User not authenticated. Please log in to place an order."}
        if symbol not in self.stocks:
            return {"error": f"Invalid stock symbol: {symbol}"}
        if price <= 0 or amount <= 0:
            return {"error": "Price and amount must be positive values."}
        price = float(price)
        # Never hand out an ID that is already in use (e.g. a scenario whose
        # order_counter points at an existing order); otherwise that order would be
        # silently overwritten below.
        while self.order_counter in self.orders:
            self.order_counter += 1
        order_id = self.order_counter
        self.orders[order_id] = {
            "id": order_id,
            "order_type": order_type,
            "symbol": symbol,
            "price": price,
            "amount": amount,
            "status": "Open",
        }
        self.order_counter += 1
        # We return the status as "Pending" to indicate that the order has been placed but not yet executed
        # When polled later, the status will show as 'Open'
        # This is to simulate the delay between placing an order and it being executed
        return {
            "order_id": order_id,
            "order_type": order_type,
            "status": "Pending",
            "price": price,
            "amount": amount,
        }

    def make_transaction(
        self, account_id: int, xact_type: str, amount: float
    ) -> Dict[str, Union[str, float]]:
        """
        Make a deposit or withdrawal based on specified amount.

        Requires authentication and an "Open" market status.

        Args:
            account_id (int): ID of the account.
            xact_type (str): Transaction type (deposit or withdrawal).
            amount (float): Amount to deposit or withdraw.

        Returns:
            status (str): Status of the transaction.
            new_balance (float): Updated account balance after the transaction.
        """
        if not self.authenticated:
            return {"error": "User not authenticated. Please log in to make a transaction."}
        if self.market_status != "Open":
            return {"error": "Market is closed. Transactions are not allowed."}
        if account_id != self.account_info["account_id"]:
            return {"error": f"Account with ID {account_id} not found."}
        if amount <= 0:
            return {"error": "Transaction amount must be positive."}
        if xact_type == "deposit":
            self.account_info["balance"] += amount
            self.transaction_history.append(
                {
                    "type": "deposit",
                    "amount": amount,
                    "timestamp": self._generate_transaction_timestamp(),
                }
            )
            return {
                "status": "Deposit successful",
                "new_balance": self.account_info["balance"],
            }
        elif xact_type == "withdrawal":
            if amount > self.account_info["balance"]:
                return {"error": "Insufficient funds for withdrawal."}
            self.account_info["balance"] -= amount
            self.transaction_history.append(
                {
                    "type": "withdrawal",
                    "amount": amount,
                    "timestamp": self._generate_transaction_timestamp(),
                }
            )
            return {
                "status": "Withdrawal successful",
                "new_balance": self.account_info["balance"],
            }
        return {"error": "Invalid transaction type. Use 'deposit' or 'withdrawal'."}

    def get_account_info(self) -> Dict[str, Union[int, float]]:
        """
        Get account information.

        Returns:
            account_id (int): ID of the account.
            balance (float): Current balance of the account.
            binding_card (int): Card number associated with the account.
        """
        if not self.authenticated:
            return {
                "error": "User not authenticated. Please log in to view account information."
            }
        return self.account_info

    def trading_login(self, username: str, password: str) -> Dict[str, str]:
        """
        Handle user login.

        Args:
            username (str): Username for authentication.
            password (str): Password for authentication.

        Returns:
            status (str): Login status message.
        """
        if self.authenticated:
            return {"status": "Already logged in"}
        # In a real system, we would validate the username and password here
        self.authenticated = True
        return {"status": "Logged in successfully"}

    def trading_get_login_status(self) -> Dict[str, bool]:
        """
        Get the login status.

        Returns:
            status (bool): Login status.
        """
        return {"status": bool(self.authenticated)}

    def trading_logout(self) -> Dict[str, str]:
        """
        Handle user logout for trading system.

        Returns:
            status (str): Logout status message.
        """
        if not self.authenticated:
            return {"status": "No user is currently logged in"}
        self.authenticated = False
        return {"status": "Logged out successfully"}

    def fund_account(self, amount: float) -> Dict[str, Union[str, float]]:
        """
        Fund the account with the specified amount.

        Unlike make_transaction, this does not require the market to be open.

        Args:
            amount (float): Amount to fund the account with.

        Returns:
            status (str): Status of the funding operation.
            new_balance (float): Updated account balance after funding.
        """
        if not self.authenticated:
            return {"error": "User not authenticated. Please log in to fund the account."}
        if amount <= 0:
            return {"error": "Funding amount must be positive."}
        self.account_info["balance"] += amount
        self.transaction_history.append(
            {"type": "deposit", "amount": amount, "timestamp": self._generate_transaction_timestamp()}
        )
        return {
            "status": "Account funded successfully",
            "new_balance": self.account_info["balance"],
        }

    def remove_stock_from_watchlist(self, symbol: str) -> Dict[str, str]:
        """
        Remove a stock from the watchlist.

        Args:
            symbol (str): Symbol of the stock to remove.

        Returns:
            status (str): Status of the removal operation.
        """
        if not self.authenticated:
            return {
                "error": "User not authenticated. Please log in to modify the watchlist."
            }
        if symbol not in self.watch_list:
            return {"error": f"Stock {symbol} not found in watchlist."}
        self.watch_list.remove(symbol)
        return {"status": f"Stock {symbol} removed from watchlist successfully."}

    def get_watchlist(self) -> Dict[str, List[str]]:
        """
        Get the watchlist.

        Returns:
            watchlist (List[str]): List of stock symbols in the watchlist.
            On failure, a dict with a single "error" key is returned instead.
        """
        if not self.authenticated:
            # Same {"error": ...} shape as every other method (previously a bare list).
            return {"error": "User not authenticated. Please log in to view the watchlist."}
        if self.long_context:
            watch_list = self.watch_list.copy()
            watch_list.extend(WATCH_LIST_EXTENSION)
            # Keep the documented {"watchlist": ...} shape in long-context mode too.
            return {"watchlist": watch_list}
        return {"watchlist": self.watch_list}

    def get_order_history(self) -> Dict[str, List[int]]:
        """
        Get the stock order ID history.

        Returns:
            history (List[int]): List of order IDs in the order history.
            On failure, a dict with a single "error" key is returned instead.
        """
        if not self.authenticated:
            # Same {"error": ...} shape as every other method (previously a list of dicts).
            return {
                "error": "User not authenticated. Please log in to view order history."
            }
        return {"history": list(self.orders.keys())}

    def get_transaction_history(
        self, start_date: Optional[str] = None, end_date: Optional[str] = None
    ) -> Dict[str, List[Dict[str, Union[str, float]]]]:
        """
        Get the transaction history within a specified date range.

        Both bounds are inclusive of the whole day they name.

        Args:
            start_date (str): [Optional] Start date for the history (format: 'YYYY-MM-DD').
            end_date (str): [Optional] End date for the history (format: 'YYYY-MM-DD').

        Returns:
            transaction_history (List[Dict]): List of transactions within the specified date range.
                - type (str): Type of transaction. [Enum]: ["deposit", "withdrawal"]
                - amount (float): Amount involved in the transaction.
                - timestamp (str): Timestamp of the transaction, formatted as 'YYYY-MM-DD HH:MM:SS'.
            On failure, a dict with a single "error" key is returned instead.

        Raises:
            ValueError: If a supplied date does not match 'YYYY-MM-DD'.
        """
        if not self.authenticated:
            # Same {"error": ...} shape as every other method (previously a list of dicts).
            return {
                "error": "User not authenticated. Please log in to view transaction history."
            }
        if start_date:
            start = datetime.strptime(start_date, "%Y-%m-%d")
        else:
            start = datetime.min
        if end_date:
            # A bare date parses to midnight; extend to the last instant of that day so
            # transactions made on end_date itself are included.
            end = datetime.combine(
                datetime.strptime(end_date, "%Y-%m-%d").date(), time.max
            )
        else:
            end = datetime.max
        filtered_history = [
            transaction
            for transaction in self.transaction_history
            if start
            <= datetime.strptime(transaction["timestamp"], "%Y-%m-%d %H:%M:%S")
            <= end
        ]
        if self.long_context:
            filtered_history.extend(TRANSACTION_HISTORY_EXTENSION)
        return {"transaction_history": filtered_history}

    def update_stock_price(
        self, symbol: str, new_price: float
    ) -> Dict[str, Union[str, float]]:
        """
        Update the price of a stock and recompute its percent_change.

        Args:
            symbol (str): Symbol of the stock to update.
            new_price (float): New price of the stock.

        Returns:
            symbol (str): Symbol of the updated stock.
            old_price (float): Previous price of the stock.
            new_price (float): Updated price of the stock.
        """
        if symbol not in self.stocks:
            return {"error": f"Stock with symbol '{symbol}' not found."}
        if new_price <= 0:
            return {"error": "New price must be a positive value."}
        old_price = self.stocks[symbol]["price"]
        self.stocks[symbol]["price"] = new_price
        self.stocks[symbol]["percent_change"] = ((new_price - old_price) / old_price) * 100
        return {"symbol": symbol, "old_price": old_price, "new_price": new_price}

    # below contains a list of functions to be nested
    def get_available_stocks(self, sector: str) -> Dict[str, List[str]]:
        """
        Get a list of stock symbols in the given sector.

        Note: some listed symbols (e.g. "F", "GM") have no entry in `self.stocks`.

        Args:
            sector (str): The sector to retrieve stocks from (e.g., 'Technology').

        Returns:
            stock_list (List[str]): List of stock symbols in the specified sector,
                or an empty list for an unknown sector.
        """
        sector_map = {
            "Technology": ["AAPL", "GOOG", "MSFT", "NVDA"],
            "Automobile": ["TSLA", "F", "GM"],
        }
        if self.long_context:
            # sector_map is rebuilt on every call, so extending it in place is safe.
            sector_map["Technology"].extend(TECHNOLOGY_EXTENSION)
            sector_map["Automobile"].extend(AUTOMOBILE_EXTENSION)
        return {"stock_list": sector_map.get(sector, [])}

    def filter_stocks_by_price(
        self, stocks: List[str], min_price: float, max_price: float
    ) -> Dict[str, List[str]]:
        """
        Filter stocks based on an inclusive price range.

        Symbols with no known price are excluded.

        Args:
            stocks (List[str]): List of stock symbols to filter.
            min_price (float): Minimum stock price.
            max_price (float): Maximum stock price.

        Returns:
            filtered_stocks (List[str]): Filtered list of stock symbols within the price range,
                in input order.
        """
        filtered_stocks = []
        for symbol in stocks:
            price = self.stocks.get(symbol, {}).get("price")
            # A symbol without price data cannot lie in any price range; previously it
            # was treated as price 0 and slipped through whenever min_price <= 0.
            if price is not None and min_price <= price <= max_price:
                filtered_stocks.append(symbol)
        return {"filtered_stocks": filtered_stocks}

    def add_to_watchlist(self, stock: str) -> Dict[str, List[str]]:
        """
        Add a stock to the watchlist.

        Unknown symbols and symbols already present are ignored.

        Args:
            stock (str): the stock symbol to add to the watchlist.

        Returns:
            symbol (List[str]): the full watchlist after the operation.
        """
        if stock not in self.watch_list:
            if stock in self.stocks:  # Ensure symbol is valid
                self.watch_list.append(stock)
        return {"symbol": self.watch_list}

    def notify_price_change(self, stocks: List[str], threshold: float) -> Dict[str, str]:
        """
        Notify if there is a significant price change in the stocks.

        Args:
            stocks (List[str]): List of stock symbols to check; unknown symbols are ignored.
            threshold (float): Percentage change threshold (compared against the
                absolute percent_change, inclusive) to trigger a notification.

        Returns:
            notification (str): Notification message about the price changes.
        """
        changed_stocks = [
            symbol
            for symbol in stocks
            if symbol in self.stocks
            and abs(self.stocks[symbol]["percent_change"]) >= threshold
        ]
        if changed_stocks:
            return {"notification": f"Stocks {', '.join(changed_stocks)} have significant price changes."}
        else:
            return {"notification": "No significant price changes in the selected stocks."}