# Gist DN6/315df8a945540a2c14b4fdda9fbe29f9 — "Spark NLP + Comet"
# Created November 3, 2021 19:03.
# Note: GitHub flagged that the original page may contain hidden or
# bidirectional Unicode text that can be interpreted or compiled differently
# than it appears; review it in an editor that reveals hidden Unicode
# characters before running.
"""Train a Spark NLP MultiClassifierDL model while a Comet logger monitors it.

Builds a three-stage pipeline (DocumentAssembler -> UniversalSentenceEncoder ->
MultiClassifierDLApproach) and fits it on a two-row toy dataset. The
classifier is configured to write training logs to ``OUTPUT_LOG_PATH``, and
``CometLogger.monitor`` is pointed at that same directory and model.
"""
import sparknlp
from sparknlp.base import *
from sparknlp.annotator import *
from sparknlp.logging.comet import CometLogger

spark = sparknlp.start()

# Directory where MultiClassifierDLApproach writes its training logs; the
# Comet logger below watches this same path.
OUTPUT_LOG_PATH = "./run"

# Stage 1: wrap the raw "text" column into Spark NLP document annotations.
document = DocumentAssembler().setInputCol("text").setOutputCol("document")

# Stage 2: sentence embeddings from the default pretrained Universal Sentence
# Encoder (NOTE(review): pretrained() presumably fetches/caches the model —
# confirm network access in the target environment).
embds = (
    UniversalSentenceEncoder.pretrained()
    .setInputCols("document")
    .setOutputCol("sentence_embeddings")
)

# Stage 3: multi-label classifier trained on the "labels" array column.
multiClassifier = (
    MultiClassifierDLApproach()
    .setInputCols("sentence_embeddings")
    .setOutputCol("category")
    .setLabelColumn("labels")
    .setBatchSize(128)
    .setLr(1e-3)
    .setThreshold(0.5)
    .setShufflePerEpoch(False)
    # Output logs must be enabled and written to OUTPUT_LOG_PATH so the
    # Comet logger has something to monitor.
    .setEnableOutputLogs(True)
    .setOutputLogsPath(OUTPUT_LOG_PATH)
    .setMaxEpochs(1)  # single epoch: this is a demo, not a real training run
)

# Two-row toy dataset; "labels" holds a list of label strings per row.
trainDataset = spark.createDataFrame(
    [("Nice.", ["positive"]), ("That's bad.", ["negative"])],
    schema=["text", "labels"],
)

pipeline = Pipeline(stages=[document, embds, multiClassifier])

# Create the logger only after the pipeline is built, so a failure while
# loading the pretrained model cannot leave a Comet run open.
logger = CometLogger()
try:
    logger.monitor(logdir=OUTPUT_LOG_PATH, model=multiClassifier)
    pipeline.fit(trainDataset)
finally:
    # Always close the Comet run, even if training raises.
    logger.end()
# (GitHub gist page footer) Sign up for free to join this conversation on
# GitHub. Already have an account? Sign in to comment.