Skip to content

Instantly share code, notes, and snippets.

@aimerneige
Created May 2, 2023 14:25
Show Gist options
  • Select an option

  • Save aimerneige/3bc8b1e85642f0efabb3d2b7bd3129d8 to your computer and use it in GitHub Desktop.

Select an option

Save aimerneige/3bc8b1e85642f0efabb3d2b7bd3129d8 to your computer and use it in GitHub Desktop.

Revisions

  1. aimerneige created this gist May 2, 2023.
    296 changes: 296 additions & 0 deletions baidu_ocr.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,296 @@
    #!/usr/env/bin python3
    # -*- coding: utf-8 -*-

    # 读取系统文件用
    import sys
    # base64 编码
    import base64
    # 网络请求库
    import requests
    # 枚举类型
    from enum import Enum
    # 图像处理
    import cv2
    # PyQt 框架
    from PyQt5 import QtCore
    from PyQt5.QtGui import QPixmap
    from PyQt5.QtCore import QSize, pyqtSlot
    from PyQt5.QtWidgets import QApplication, QDesktopWidget, QMainWindow, QPushButton, QLabel, QTextEdit, QComboBox, QFileDialog, QMessageBox

    # 窗口标题
    window_title = "文本识别"

    API_KEY = "YOUR_API_KEY_HERE"
    SECRET_KEY = "YOUR_SECRET_KEY_HERE"

    TOKEN_URL = 'https://aip.baidubce.com/oauth/2.0/token'
    OCR_URL = 'https://aip.baidubce.com/rest/2.0/ocr/v1/'


    # OCR 语言接口参数封装
    class Language(Enum):
    auto_detect = 0
    CHN_ENG = 1
    ENG = 2
    JAP = 3
    KOR = 4
    FRE = 5
    SPA = 6
    POR = 7
    GER = 8
    ITA = 9
    RUS = 10
    DAN = 11
    DUT = 12
    MAL = 13
    SWE = 14
    IND = 15
    POL = 16
    ROM = 17
    TUR = 18
    GRE = 19
    HUN = 20


    # OCR 调用封装类
    class OCR(object):

    def __init__(self, api_key, secret_key):
    super().__init__()
    self.API_KEY = api_key
    self.SECRET_KEY = secret_key
    self.OCR_TYPE = "通用文字"
    self.ACCESS_TOKEN = self.fetch_token()
    self.LANGUAGE = Language.auto_detect.name
    self.DETECT_DIRECTION = "false"
    self.PARAGRAPH = "false"
    self.PROBABILITY = "true"

    def set_language(self, language):
    self.LANGUAGE = language.name

    def set_detect_direction(self, detect_direction):
    self.DETECT_DIRECTION = detect_direction

    def set_paragraph(self, paragraph):
    self.PARAGRAPH = paragraph

    def set_probability(self, probability):
    self.PROBABILITY = probability

    # 获取 Token
    def fetch_token(self):
    response = requests.post(TOKEN_URL, data={
    'grant_type': 'client_credentials',
    'client_id': self.API_KEY,
    'client_secret': self.SECRET_KEY
    })
    if response:
    return response.json()['access_token']

    # 编码图片文件
    def encode_image(self, image_path):
    with open(image_path, 'rb') as f:
    image_data = f.read()
    return base64.b64encode(image_data)

    # 编码 PDF 文件
    def encode_pdf(self, pdf_path):
    with open(pdf_path, 'rb') as f:
    pdf_data = f.read()
    def accurate_basic(self, image_path):
    ocr_url = OCR_URL + "accurate_basic"
    request_url = ocr_url + '?access_token=' + self.ACCESS_TOKEN
    headers = {'Content-Type': 'application/x-www-form-urlencoded'}
    params = {
    'image': self.encode_image(image_path),
    'language_type': self.LANGUAGE,
    'detect_direction': self.DETECT_DIRECTION,
    'paragraph': self.PARAGRAPH,
    'probability': self.PROBABILITY,
    }
    json_data = requests.post(
    request_url, headers=headers, data=params).json()
    result_text = ""
    for paragraphs in json_data["paragraphs_result"]:
    for index in paragraphs["words_result_idx"]:
    result_text += json_data["words_result"][index]["words"]
    result_text += " "
    result_text += "\n"
    return result_text

    # 数字识别
    def numbers(self, image_path):
    ocr_url = OCR_URL + "numbers"
    request_url = ocr_url + "?access_token=" + self.ACCESS_TOKEN
    headers = {'content-type': 'application/x-www-form-urlencoded'}
    params = {
    'image': self.encode_image(image_path),
    'detect_direction': self.DETECT_DIRECTION,
    }
    json_data = requests.post(
    request_url, headers=headers, data=params).json()
    result_text = ""
    for result in json_data["words_result"]:
    numbers = result['words']
    result_text += numbers
    result_text += "\n"
    return result_text

    # 手写识别
    def handwriting(self, image_path):
    ocr_url = OCR_URL + "handwriting"
    request_url = ocr_url + "?access_token=" + self.ACCESS_TOKEN
    headers = {'content-type': 'application/x-www-form-urlencoded'}
    params = {
    'image': self.encode_image(image_path),
    'detect_direction': self.DETECT_DIRECTION,
    'probability': self.PROBABILITY,
    }
    json_data = requests.post(
    request_url, headers=headers, data=params).json()
    result_text = ""
    for result in json_data["words_result"]:
    words = result['words']
    result_text += words
    result_text += "\n"
    return result_text

    # 获得 OCR 结果
    def get_ocr_result(self, image_path):
    if self.OCR_TYPE == "通用文字":
    return self.accurate_basic(image_path)
    elif self.OCR_TYPE == "数字识别":
    return self.numbers(image_path)
    elif self.OCR_TYPE == "手写文字":
    return self.handwriting(image_path)


    # 窗口类
    class Window(QMainWindow):
    def __init__(self):
    super().__init__()
    self.initWindow()
    self.initUI()
    self.initOCR()
    self.center()

    # 初始化窗口大小
    def initWindow(self):
    self.setWindowTitle(window_title)
    self.setFixedWidth(1280)
    self.setFixedHeight(720)

    # 初始化窗口界面
    def initUI(self):
    self.initOCROptionSelection()
    self.initImageSelection()
    self.initTextResult()

    # 初始化 OCR 类
    def initOCR(self):
    self.OCR = OCR(API_KEY, SECRET_KEY)
    self.OCR.set_detect_direction("false")
    self.OCR.set_language(Language.CHN_ENG)
    self.OCR.set_paragraph("true")
    self.OCR.set_probability("false")

    # OCR 识别类型控件
    def initOCROptionSelection(self):
    self.ocrOptionLabel = QLabel("选择识别类型", self)
    self.ocrOptionLabel.setFixedSize(QSize(220, 40))
    self.ocrOptionLabel.move(80, 80)
    self.ocrOptionSelection = QComboBox(self)
    self.ocrOptionSelection.move(320, 80)
    self.ocrOptionSelection.setFixedSize(QSize(240, 40))
    self.ocrOptionSelection.addItem("通用文字")
    self.ocrOptionSelection.addItem("数字识别")
    self.ocrOptionSelection.addItem("手写文字")
    # 绑定修改事件
    self.ocrOptionSelection.currentIndexChanged.connect(
    self.selectionChange)

    # 图片选择及演示控件
    def initImageSelection(self):
    self.imageSelectionButton = QPushButton("选择需要识别的图片", self)
    self.imageSelectionButton.setFixedSize(QSize(220, 80))
    self.imageSelectionButton.move(80, 180)
    # 绑定点击事件
    self.imageSelectionButton.clicked.connect(self.imageSelectionClicked)
    self.cameraSelectionButton = QPushButton("使用摄像头拍照", self)
    self.cameraSelectionButton.setFixedSize(QSize(220, 80))
    self.cameraSelectionButton.move(340, 180)
    # 绑定点击事件
    self.cameraSelectionButton.clicked.connect(self.cameraSelectionClicked)
    self.imagePreviewImage = QLabel(self)
    self.imagePreviewImage.setText("请选择要识别的图片")
    self.imagePreviewImage.setStyleSheet(
    "QLabel { background-color : gray; color : black; }")
    self.imagePreviewImage.setAlignment(QtCore.Qt.AlignCenter)
    self.imagePreviewImage.setFixedSize(QSize(480, 320))
    self.imagePreviewImage.setScaledContents(True)
    self.imagePreviewImage.move(80, 320)

    # 识别结果控件
    def initTextResult(self):
    self.textResult = QTextEdit(self)
    self.textResult.setFixedSize(QSize(480, 440))
    self.textResult.move(720, 80)
    self.textCopy = QPushButton("复制到剪切板", self)
    self.textCopy.setFixedSize(QSize(480, 80))
    self.textCopy.move(720, 560)
    # 绑定点击事件
    self.textCopy.clicked.connect(self.copyClicked)

    # 窗口居中
    def center(self):
    qr = self.frameGeometry()
    cp = QDesktopWidget().availableGeometry().center()
    qr.moveCenter(cp)
    self.move(qr.topLeft())

    # 识别类型选择
    def selectionChange(self, i):
    self.OCR.OCR_TYPE = self.ocrOptionSelection.currentText()

    # 选择系统图片
    @pyqtSlot()
    def imageSelectionClicked(self):
    selected_file = QFileDialog.getOpenFileName(self, "选择你要识别的图片", "~/")
    file_path = selected_file[0]
    self.imagePreviewImage.setPixmap(QPixmap(file_path))
    self.callOCR(file_path)

    @pyqtSlot()
    def cameraSelectionClicked(self):
    cam = cv2.VideoCapture(0)
    _, img = cam.read()
    cam_img_path = "./camera_temp.png"
    cv2.imwrite(cam_img_path, img)
    self.imagePreviewImage.setPixmap(QPixmap(cam_img_path))
    self.callOCR(cam_img_path)

    # 复制到剪切板
    @pyqtSlot()
    def copyClicked(self):
    print("copy")
    QApplication.clipboard().setText(self.textResult.toPlainText())
    msg = QMessageBox(self)
    msg.setText('已复制到剪切板')
    msg.exec_()

    # 调用 OCR 接口
    def callOCR(self, image_path):
    ocr_result = self.OCR.get_ocr_result(image_path)
    self.textResult.setPlainText(ocr_result)


    def main():
    app = QApplication(sys.argv)
    window = Window()
    window.show()
    sys.exit(app.exec_())


    if __name__ == '__main__':
    main()