Created
March 15, 2023 01:26
-
-
Save s10018/c4f3a52041df1fdccd14a544686230ed to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # -*- coding: utf-8 -*- | |
| import argparse | |
| import dataclasses | |
| from typing import Iterator | |
@dataclasses.dataclass
class Token:
    """A single token together with the character span it covers.

    Instances are created positionally as ``Token(start, end, token)``.
    """

    # Offset of the token's first character.
    start: int
    # Offset one past the token's last character (start + len(token)).
    end: int
    # Surface string of the token.
    token: str
@dataclasses.dataclass
class Score:
    """Precision, recall and F1 bundled as one result record.

    Instances are created positionally as ``Score(precision, recall, f1)``.
    """

    # Fraction of system spans that are correct.
    precision: float
    # Fraction of gold spans that were recovered.
    recall: float
    # Combined F1 value.
    f1: float
def count_spans(sys_spans: list[Token], gold_spans: list[Token]) -> tuple[int, int, int]:
    """Count the spans on which the system and gold segmentations agree.

    Both lists are walked front to back in lockstep, always advancing the
    side whose current span starts earlier. The walk never moves backwards,
    so it relies on each list being ordered by ascending ``start``.

    Args:
        sys_spans: Spans produced by the system.
        gold_spans: Reference spans.

    Returns:
        ``(number of system spans, number of gold spans, number of spans
        with identical start and end in both lists)``.
    """
    sys_iter, gold_iter = iter(sys_spans), iter(gold_spans)
    sys_cur, gold_cur = next(sys_iter, None), next(gold_iter, None)
    matched = 0
    # Stop as soon as either side is exhausted: nothing further can match.
    while sys_cur is not None and gold_cur is not None:
        if sys_cur.start < gold_cur.start:
            sys_cur = next(sys_iter, None)
        elif gold_cur.start < sys_cur.start:
            gold_cur = next(gold_iter, None)
        else:
            # Same start offset: the span is correct only if the ends agree too.
            if gold_cur.end == sys_cur.end:
                matched += 1
            sys_cur, gold_cur = next(sys_iter, None), next(gold_iter, None)
    return len(sys_spans), len(gold_spans), matched
def cal_f1(sys_total: int, gold_total: int, correct: int) -> Score:
    """Turn span counts into a ``Score`` of precision, recall and F1.

    Any ratio whose denominator is not positive is reported as ``0.0``
    instead of raising.

    Args:
        sys_total: Number of spans produced by the system.
        gold_total: Number of spans in the reference.
        correct: Number of spans counted as correct.

    Returns:
        ``Score(correct / sys_total, correct / gold_total,
        2 * correct / (sys_total + gold_total))``.
    """
    precision = correct / sys_total if sys_total > 0 else 0.0
    recall = correct / gold_total if gold_total > 0 else 0.0
    span_sum = sys_total + gold_total
    f1 = (2 * correct) / span_sum if span_sum > 0 else 0.0
    return Score(precision, recall, f1)
def iterate_data(filename: str) -> Iterator[list[Token]]:
    """Yield the tokens of ``filename`` one sentence at a time.

    The file is read as one token per line; a line consisting solely of
    ``EOS`` closes the current sentence. Each token's ``start``/``end`` are
    offsets obtained by accumulating token lengths from the beginning of
    its sentence.

    Args:
        filename: Path of the token file, read as UTF-8.

    Yields:
        The list of ``Token`` objects of each sentence, in file order. An
        ``EOS`` line with no preceding tokens yields an empty list.
    """
    toks: list[Token] = []
    start = 0
    # Decode explicitly as UTF-8 so the result does not depend on the locale.
    with open(filename, "r", encoding="utf-8") as rdr:
        for line in rdr:
            token = line.rstrip("\n")
            if token == "EOS":
                yield toks
                toks, start = [], 0
                continue
            if not token:
                # A blank line would become a zero-width span that shares its
                # start offset with the following token; ignore it.
                continue
            end = start + len(token)
            toks.append(Token(start, end, token))
            start = end
    # Do not lose the final sentence when the file lacks a closing EOS line.
    if toks:
        yield toks
if __name__ == '__main__':
    # Command-line entry point: compare a system tokenization against a gold
    # one and print, space-separated on a single line:
    #   macro-averaged F1, micro-averaged F1, system span total,
    #   gold span total, correct span total.
    import sys

    parser = argparse.ArgumentParser(description='トークンの精度F1を出す')
    parser.add_argument('sys_file')
    parser.add_argument('gold_file')
    args = parser.parse_args()

    macro_f, size = 0.0, 0
    sys_total_s, gold_total_s, correct_s = 0, 0, 0
    skipped = 0
    for sys_tok, gold_tok in zip(iterate_data(args.sys_file), iterate_data(args.gold_file)):
        # Spans are only comparable when both sides segment the same text.
        if "".join(s.token for s in sys_tok) != "".join(s.token for s in gold_tok):
            skipped += 1
            continue
        sys_total, gold_total, correct = count_spans(sys_tok, gold_tok)
        score = cal_f1(sys_total, gold_total, correct)
        sys_total_s += sys_total
        gold_total_s += gold_total
        correct_s += correct
        macro_f += score.f1
        size += 1

    if skipped:
        # Report on stderr so the result line on stdout keeps its format.
        print(f"warning: skipped {skipped} sentence pair(s) whose text differs", file=sys.stderr)

    # With no scored sentence there is nothing to average; report 0.0 rather
    # than failing with ZeroDivisionError.
    macro_avg = macro_f / size if size > 0 else 0.0
    print(macro_avg, cal_f1(sys_total_s, gold_total_s, correct_s).f1, sys_total_s, gold_total_s, correct_s)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment