Skip to content

Instantly share code, notes, and snippets.

@DN6
Created November 3, 2021 19:03
Show Gist options
  • Select an option

  • Save DN6/315df8a945540a2c14b4fdda9fbe29f9 to your computer and use it in GitHub Desktop.

Select an option

Save DN6/315df8a945540a2c14b4fdda9fbe29f9 to your computer and use it in GitHub Desktop.
Spark NLP + Comet
import sparknlp
from sparknlp.base import *
from sparknlp.annotator import *
from sparknlp.logging.comet import CometLogger
# Example script: train a Spark NLP multi-label classifier for one epoch while
# a CometLogger watches the directory the classifier writes its logs to.

# Spark session used below to build the training DataFrame.
spark = sparknlp.start()

# Single source of truth for the log directory: the classifier is told to
# write its output logs here and the Comet logger is told to monitor it.
OUTPUT_LOG_PATH = "./run"
logger = CometLogger()

# Stage 1: turn the raw "text" column into a "document" annotation column.
document = DocumentAssembler().setInputCol("text").setOutputCol("document")

# Stage 2: sentence embeddings computed from the "document" column.
embds = (
    UniversalSentenceEncoder.pretrained()
    .setInputCols("document")
    .setOutputCol("sentence_embeddings")
)

# Stage 3: the trainable multi-label classifier. Output logs must be enabled
# and pointed at OUTPUT_LOG_PATH, otherwise the monitored directory would
# receive nothing from this stage.
multiClassifier = (
    MultiClassifierDLApproach()
    .setInputCols("sentence_embeddings")
    .setOutputCol("category")
    .setLabelColumn("labels")
    .setBatchSize(128)
    .setLr(1e-3)
    .setThreshold(0.5)
    .setShufflePerEpoch(False)
    .setEnableOutputLogs(True)
    .setOutputLogsPath(OUTPUT_LOG_PATH)
    .setMaxEpochs(1)
)

# Monitoring is started before training so it is already active when the
# classifier begins writing logs.
logger.monitor(logdir=OUTPUT_LOG_PATH, model=multiClassifier)

# From here on the logger must be shut down no matter what happens: without
# the try/finally, an exception while building the dataset or fitting the
# pipeline would skip logger.end() entirely.
try:
    # Two-row toy dataset; "labels" holds a list of labels per row.
    trainDataset = spark.createDataFrame(
        [("Nice.", ["positive"]), ("That's bad.", ["negative"])],
        schema=["text", "labels"],
    )
    pipeline = Pipeline(stages=[document, embds, multiClassifier])
    pipeline.fit(trainDataset)
finally:
    logger.end()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment