Created July 8, 2024 09:25
-
-
Save EzhilAdhavan/a816aaeb73d9096358a9b6d5a5f6c09c to your computer and use it in GitHub Desktop.
Serve a sentiment-analysis model with FastAPI, loading the model and tokenizer from a local path.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
| from transformers import pipeline | |
| import os | |
# FastAPI application instance; the /predict route below is registered on it.
app = FastAPI()

# Directory expected to contain the saved model and tokenizer files.
# Can be overridden with the MODEL_PATH environment variable; the default is
# the original hard-coded location, so existing deployments are unaffected.
local_model_path = os.environ.get("MODEL_PATH", "./models/bertweet_sentiment/")

# Fail fast with a clear message if the model directory is missing.
# isdir (not exists) so that a stray *file* at this path is also rejected here.
if not os.path.isdir(local_model_path):
    raise FileNotFoundError(
        f"The directory '{local_model_path}' does not exist. Ensure the model and tokenizer are saved correctly."
    )

try:
    model = AutoModelForSequenceClassification.from_pretrained(local_model_path)
    tokenizer = AutoTokenizer.from_pretrained(local_model_path)
except Exception as e:
    # Chain explicitly so the underlying loader error and its traceback are
    # reported as the direct cause of this RuntimeError.
    raise RuntimeError(
        f"Failed to load model or tokenizer from '{local_model_path}': {e}"
    ) from e

# top_k=None asks the pipeline for scores of all labels, not just the best one.
classifier = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer, top_k=None)
class TextInput(BaseModel):
    """Request body accepted by the ``/predict`` endpoint."""

    # Raw text whose sentiment should be classified.
    text: str
@app.post("/predict")
async def predict(input: TextInput):  # noqa: A002 - name kept for API compatibility
    """Classify the sentiment of the submitted text.

    Args:
        input: Request body carrying the text to classify.

    Returns:
        The classifier pipeline's output, unchanged. The pipeline is
        configured with ``top_k=None``, so this includes every label's score.
    """
    # Local import keeps this change self-contained; fastapi is already a dependency.
    from fastapi.concurrency import run_in_threadpool

    # Model inference is synchronous and compute-bound. Calling it directly in
    # an ``async def`` handler would block the event loop and stall all other
    # requests, so run it in the worker thread pool instead.
    # ``truncation=True`` clips texts longer than the tokenizer's maximum
    # length, which would otherwise make the model raise on long inputs.
    predictions = await run_in_threadpool(classifier, input.text, truncation=True)
    return predictions
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn when run directly.
    import uvicorn

    bind_host, bind_port = "0.0.0.0", 8000  # all interfaces, port 8000
    uvicorn.run(app, host=bind_host, port=bind_port)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.