|
import os
import secrets
import warnings

import gradio as gr
import uvicorn
from authlib.integrations.starlette_client import OAuth, OAuthError
from fastapi import FastAPI, Depends
from fastapi import Request
from starlette.config import Config
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import RedirectResponse
|
|
|
app = FastAPI()


# OAuth client credentials for the provider registered below (Twitter).
# TWITTER_* environment variables are read first; the GOOGLE_* names this app
# originally (misleadingly) used are still honoured as a fallback so existing
# deployments keep working.
TWITTER_CLIENT_ID = os.environ.get("TWITTER_CLIENT_ID") or os.environ.get("GOOGLE_CLIENT_ID")
TWITTER_CLIENT_SECRET = (
    os.environ.get("TWITTER_CLIENT_SECRET") or os.environ.get("GOOGLE_CLIENT_SECRET")
)
# Backward-compatible aliases for the original module-level names.
GOOGLE_CLIENT_ID = TWITTER_CLIENT_ID
GOOGLE_CLIENT_SECRET = TWITTER_CLIENT_SECRET

# Key used by SessionMiddleware to sign the session cookie.
SECRET_KEY = os.environ.get("SECRET_KEY")
if not SECRET_KEY:
    # Starlette stringifies the key, so passing None would sign every session
    # with the literal, guessable key "None". Use a random per-process key
    # instead. Sessions then do not survive restarts and are not shared between
    # worker processes, so set SECRET_KEY in production.
    warnings.warn(
        "SECRET_KEY is not set; using a random, per-process session key.",
        RuntimeWarning,
        stacklevel=1,
    )
    SECRET_KEY = secrets.token_urlsafe(32)


# Authlib looks up `<NAME>_CLIENT_ID` / `<NAME>_CLIENT_SECRET` in this config
# for the client registered with name='twitter'.
config_data = {'TWITTER_CLIENT_ID': TWITTER_CLIENT_ID, 'TWITTER_CLIENT_SECRET': TWITTER_CLIENT_SECRET}
starlette_config = Config(environ=config_data)
oauth = OAuth(starlette_config)
# Supplying a request_token_url makes this an OAuth 1.0a client (Twitter API v1.1).
oauth.register(
    name='twitter',
    api_base_url='https://api.twitter.com/1.1/',
    request_token_url='https://api.twitter.com/oauth/request_token',
    access_token_url='https://api.twitter.com/oauth/access_token',
    authorize_url='https://api.twitter.com/oauth/authenticate',
)


# Cookie-backed sessions. The OAuth flow keeps temporary request-token data
# here, and the logged-in user is stored under 'user'.
app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY)
|
|
|
|
|
def get_user(request: Request):
    """Return the logged-in user's display name from the session, or ``None``.

    Reads the ``'user'`` entry written by the OAuth callback. A missing or empty
    entry means "not logged in". This function serves as the FastAPI dependency
    for ``/`` and as the Gradio ``auth_dependency`` for the app mounted at
    ``/gradio``.
    """
    session_user = request.session.get('user')
    return session_user['name'] if session_user else None
|
|
|
@app.get('/')
def public(request: Request, user=Depends(get_user)):
    """Route visitors to the Gradio app if logged in, otherwise to the login page.

    Redirect targets are built on the root URL returned by Gradio's
    ``get_root_url`` rather than a bare ``/``.
    """
    base = gr.route_utils.get_root_url(request, "/", None)
    destination = 'gradio' if user else 'main'
    return RedirectResponse(url=f'{base}/{destination}/')
|
|
|
@app.route('/logout')
async def logout(request: Request):
    """Forget the logged-in user (if any) and send the browser back to ``/``."""
    session = request.session
    if 'user' in session:
        del session['user']
    return RedirectResponse(url='/')
|
|
|
@app.route('/login')
async def login(request: Request):
    """Start the Twitter OAuth 1.0a flow.

    Redirects the browser to Twitter's authorization page. Twitter then calls
    back to this app's ``auth`` route.
    """
    # Resolve the callback URL once, and do not print it on every request.
    redirect_uri = request.url_for('auth')
    return await oauth.twitter.authorize_redirect(request, redirect_uri)
|
|
|
|
|
@app.get('/auth')
async def auth(request: Request):
    """OAuth callback: exchange the token, fetch the profile, and log the user in.

    On success a small subset of the Twitter profile is stored in
    ``request.session['user']`` and the browser is sent to ``/gradio``. If the
    token exchange raises ``OAuthError`` (for example, the user denied access or
    the temporary request-token data is missing), or the profile request does
    not return HTTP 200, the browser is sent back to ``/``.
    """
    try:
        token = await oauth.twitter.authorize_access_token(request)
    except OAuthError:
        # Treat a failed or cancelled authorization as "not logged in" rather
        # than letting it surface as a 500 error.
        return RedirectResponse(url='/')

    resp = await oauth.twitter.get(
        'account/verify_credentials.json',
        params={'skip_status': True},
        token=token,
    )
    if resp.status_code != 200:
        return RedirectResponse(url='/')

    profile = resp.json()
    # The session lives in a signed cookie (SessionMiddleware), and the full
    # verify_credentials payload can exceed browser cookie size limits. Keep
    # only the identifying fields; get_user reads 'name'.
    request.session['user'] = {
        key: profile[key]
        for key in ('id_str', 'name', 'screen_name')
        if key in profile
    }
    return RedirectResponse(url='/gradio')
|
|
|
with gr.Blocks() as login_demo:
    btn = gr.Button("Login")
    # Client-side click handler: open the /login route in a new tab, forwarding
    # the current query string. `const` keeps `url` local to the handler instead
    # of creating an implicit global (which throws in strict-mode JS).
    _js_redirect = """
    () => {
        const url = '/login' + window.location.search;
        window.open(url, '_blank');
    }
    """
    btn.click(None, js=_js_redirect)

# Serve the login page at /main. It is reachable without a session.
app = gr.mount_gradio_app(app, login_demo, path="/main")
|
|
|
def greet(request: gr.Request):
    """Build the welcome message shown once the main page loads.

    ``request.username`` is expected to hold the value returned by the
    ``auth_dependency`` used when mounting the main app. Verify this against
    the installed Gradio version.
    """
    username = request.username
    return f"Welcome to Gradio, {username}"
|
|
|
with gr.Blocks() as main_demo:
    # Placeholder text, replaced by greet()'s output on page load.
    m = gr.Markdown("Welcome to Gradio!")
    gr.Button("Logout", link="/logout")
    main_demo.load(fn=greet, inputs=None, outputs=m)

# Serve the main app at /gradio. Gradio calls get_user for access control, and
# a None result means the request is not authorized.
app = gr.mount_gradio_app(
    app,
    main_demo,
    path="/gradio",
    auth_dependency=get_user,
)
|
|
|
|
|
if __name__ == '__main__':
    # Local development server. These are uvicorn's default host and port,
    # written out explicitly.
    uvicorn.run(app, host="127.0.0.1", port=8000)
|
|