Skip to content

Instantly share code, notes, and snippets.

@prratek
Created November 30, 2018 20:12
Show Gist options
  • Select an option

  • Save prratek/8a5bd55aeea7becab976924542d6c0fb to your computer and use it in GitHub Desktop.

Select an option

Save prratek/8a5bd55aeea7becab976924542d6c0fb to your computer and use it in GitHub Desktop.
Modified serve file for Zeit deployment
from starlette.applications import Starlette
from starlette.responses import HTMLResponse, JSONResponse
from starlette.staticfiles import StaticFiles
from starlette.middleware.cors import CORSMiddleware
import uvicorn, aiohttp, asyncio
from io import BytesIO
from fastai import *
from fastai.vision import *
# Remote locations and local base names of the two artefacts that
# setup_learner() downloads: the model weights and the data bunch file.
model_file_url = 'https://drive.google.com/uc?export=download&id=1a5V6Nwg_lUSPPGinWR9CWhDHYcMBiDCQ'
model_file_name = 'model'
data_bunch_url = 'https://drive.google.com/uc?export=download&id=1CQlwniokiti_IS4hOqWu3iaWknkOBr5n'
data_bunch_name = 'export'

# Painting-style labels. Not referenced by any other code visible in this file.
classes = [
    'realist',
    'surrealist',
    'pop',
    'baroque',
    'impressionist',
    'cubist',
]

# Directory containing this module; downloads and the HTML view are
# resolved relative to it.
path = Path(__file__).parent

# ASGI application: CORS open to every origin, static assets under /static.
app = Starlette()
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_headers=['X-Requested-With', 'Content-Type'],
)
# NOTE: this directory is relative to the process working directory,
# not to ``path``.
app.mount('/static', StaticFiles(directory='app/static'))
async def download_file(url, dest):
    """Download ``url`` to ``dest`` unless ``dest`` already exists.

    Args:
        url: Address to fetch with an HTTP GET.
        dest: ``pathlib.Path`` of the file to create. If it already exists
            the function returns immediately without any network access.

    Raises:
        aiohttp.ClientResponseError: If the server answers with an HTTP
            error status (>= 400). Nothing is written in that case.
    """
    if dest.exists():
        return
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            # Fail loudly on HTTP errors: otherwise an error page would be
            # saved as the artefact and, because ``dest`` then exists, never
            # be downloaded again.
            response.raise_for_status()
            data = await response.read()
    # The target directory (e.g. ``models/``) may be missing on a fresh checkout.
    dest.parent.mkdir(parents=True, exist_ok=True)
    with open(dest, 'wb') as f:
        f.write(data)
async def setup_learner():
    """Fetch the model artefacts if needed and return a loaded learner.

    Downloads the weights to ``path/models/<model_file_name>.pth`` and the
    data bunch file to ``path/<data_bunch_name>.pkl`` (each download is
    skipped by ``download_file`` when the file already exists), builds an
    empty image data bunch rooted at ``path``, creates a ResNet-34 learner
    without pretrained weights and loads the saved weights into it.

    Returns:
        The learner after ``load(model_file_name)`` has been called on it.
    """
    weights_path = path / 'models' / f'{model_file_name}.pth'
    bunch_path = path / f'{data_bunch_name}.pkl'
    await download_file(model_file_url, weights_path)
    await download_file(data_bunch_url, bunch_path)

    empty_bunch = ImageDataBunch.load_empty(path, tfms=get_transforms(), size=224)
    normalized_bunch = empty_bunch.normalize(imagenet_stats)

    learner = create_cnn(normalized_bunch, models.resnet34, pretrained=False)
    learner.load(model_file_name)
    return learner
# Build the learner once, at import time, so the module-level ``learn`` is
# available to the request handlers defined below.
loop = asyncio.get_event_loop()
# Wrap the setup coroutine in a task scheduled on the current event loop.
tasks = [asyncio.ensure_future(setup_learner())]
# gather() yields results in the order of ``tasks``; element 0 is the learner.
learn = loop.run_until_complete(asyncio.gather(*tasks))[0]
# NOTE(review): this closes the loop obtained above. Presumably the server
# started later creates or obtains its own loop -- confirm against the
# uvicorn version in use.
loop.close()
@app.route('/')
def index(request):
    """Serve the landing page stored at ``view/index.html`` under ``path``.

    Args:
        request: Incoming Starlette request (unused).

    Returns:
        An ``HTMLResponse`` whose body is the content of the HTML file.
    """
    html = path/'view'/'index.html'
    # read_text() opens, reads and closes the file, so no handle is left open.
    return HTMLResponse(html.read_text())
@app.route('/analyze', methods=['POST'])
async def analyze(request):
    """Classify an uploaded image and return the prediction as JSON.

    Expects a form POST with the image in a field named ``file``.

    Args:
        request: Incoming Starlette request carrying the form data.

    Returns:
        ``JSONResponse`` with ``{'result': <prediction as string>}`` on
        success, or a 400 response with an ``error`` key when the ``file``
        field is missing.
    """
    data = await request.form()
    if 'file' not in data:
        # Report a client error instead of failing with a KeyError (HTTP 500).
        return JSONResponse({'error': "missing 'file' field in form data"}, status_code=400)
    img_bytes = await (data['file'].read())
    img = open_image(BytesIO(img_bytes))
    prediction = learn.predict(img)[0]
    # The first element returned by predict() is a fastai object rather than a
    # plain JSON type; convert it to its string form before serialising.
    return JSONResponse({'result': str(prediction)})
if __name__ == '__main__':
    # Imported explicitly: the file has no ``import sys`` of its own and would
    # otherwise depend on a star import happening to re-export the name.
    import sys

    # Start the server only when the script is invoked with a ``serve`` argument.
    if 'serve' in sys.argv:
        uvicorn.run(app, host='0.0.0.0', port=5042)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment