Skip to content

Instantly share code, notes, and snippets.

@s10018
Created March 15, 2023 01:26
Show Gist options
  • Select an option

  • Save s10018/c4f3a52041df1fdccd14a544686230ed to your computer and use it in GitHub Desktop.

Select an option

Save s10018/c4f3a52041df1fdccd14a544686230ed to your computer and use it in GitHub Desktop.
# -*- coding: utf-8 -*-
import argparse
import dataclasses
from typing import Iterator
@dataclasses.dataclass
class Token:
    """One token and the character span it covers within its sentence.

    ``iterate_data`` builds these by laying the sentence's tokens end to end,
    so ``end == start + len(token)`` and the span is half-open ``[start, end)``.
    """

    start: int  # offset of the token's first character in the sentence
    end: int  # offset one past the token's last character
    token: str  # surface string of the token
@dataclasses.dataclass
class Score:
    """Precision / recall / F1 triple as returned by ``cal_f1``."""

    precision: float  # correct spans / system spans
    recall: float  # correct spans / gold spans
    f1: float  # 2 * correct / (system spans + gold spans)
def count_spans(sys_spans: list[Token], gold_spans: list[Token]) -> tuple[int, int, int]:
    """Count system spans, gold spans, and spans present in both.

    Both lists are walked in parallel by start offset, so each is expected to
    be ordered by ascending ``start`` (as ``iterate_data`` produces them). When
    two spans share a start offset they are consumed together and count as
    correct only if their ``end`` offsets also agree.

    Returns:
        ``(len(sys_spans), len(gold_spans), number_of_matching_spans)``.
    """
    n_sys, n_gold = len(sys_spans), len(gold_spans)
    matched = 0
    s = g = 0
    while s < n_sys and g < n_gold:
        sys_start = sys_spans[s].start
        gold_start = gold_spans[g].start
        if sys_start == gold_start:
            # Same left boundary: a hit only if the right boundary agrees too.
            if sys_spans[s].end == gold_spans[g].end:
                matched += 1
            s += 1
            g += 1
        elif sys_start < gold_start:
            s += 1  # system span begins earlier; it has no gold counterpart
        else:
            g += 1  # gold span begins earlier; it has no system counterpart
    return n_sys, n_gold, matched
def cal_f1(sys_total: int, gold_total: int, correct: int) -> Score:
    """Turn span counts into precision, recall and F1.

    Any ratio whose denominator is not positive is reported as 0.0 rather
    than raising ``ZeroDivisionError``.

    Args:
        sys_total: number of spans produced by the system.
        gold_total: number of spans in the gold standard.
        correct: number of spans the two agree on.

    Returns:
        ``Score(correct / sys_total, correct / gold_total,
        2 * correct / (sys_total + gold_total))``.
    """
    def safe_ratio(numerator: int, denominator: int) -> float:
        if denominator > 0:
            return numerator / denominator
        return 0.0

    precision = safe_ratio(correct, sys_total)
    recall = safe_ratio(correct, gold_total)
    # 2c / (s + g) equals the harmonic mean of precision and recall.
    f1 = safe_ratio(2 * correct, sys_total + gold_total)
    return Score(precision, recall, f1)
def iterate_data(filename: str, encoding: str = "utf-8") -> Iterator[list[Token]]:
    """Yield the sentences of a one-token-per-line file as lists of ``Token``.

    Each line (without its trailing newline) is one token. A line that is
    exactly ``EOS`` ends the current sentence. Token offsets are character
    positions within the sentence's concatenated text, restarting at 0 for
    every sentence.

    Args:
        filename: path of the token file.
        encoding: text encoding of the file. Defaults to UTF-8 so results do
            not depend on the platform's locale encoding.

    Yields:
        One list of tokens per sentence, in file order. Consecutive ``EOS``
        lines yield an empty list. A trailing sentence that is not followed
        by ``EOS`` is still yielded.
    """
    toks: list[Token] = []
    start = 0
    with open(filename, "r", encoding=encoding) as rdr:
        for line in rdr:
            token = line.rstrip("\n")
            if token == "EOS":
                yield toks
                toks, start = [], 0
                continue
            if not token:
                # A blank line carries no characters. Recording it would add a
                # zero-length span that can never be correct and that collides
                # with the start offset of the following real token.
                continue
            end = start + len(token)
            toks.append(Token(start, end, token))
            start = end
    if toks:
        # Do not silently drop a last sentence whose EOS marker is missing.
        yield toks
if __name__ == '__main__':
    import itertools
    import sys

    # Description reads: "compute token-level F1".
    parser = argparse.ArgumentParser(description='トークンの精度F1を出す')
    parser.add_argument('sys_file')
    parser.add_argument('gold_file')
    args = parser.parse_args()

    # Running totals: summed per-sentence F1 (for the macro average) and raw
    # span counts pooled over all sentences (for the micro average).
    macro_f, sys_total_s, gold_total_s, correct_s, size = 0.0, 0, 0, 0, 0
    skipped = 0  # sentence pairs whose surface text differs
    unpaired = 0  # sentences that have no counterpart in the other file

    # zip_longest instead of zip so a difference in sentence count is
    # reported rather than silently truncated to the shorter file.
    pairs = itertools.zip_longest(iterate_data(args.sys_file), iterate_data(args.gold_file))
    for sys_tok, gold_tok in pairs:
        if sys_tok is None or gold_tok is None:
            unpaired += 1
            continue
        # Spans are only comparable when both sides segment the same text.
        if "".join(t.token for t in sys_tok) != "".join(t.token for t in gold_tok):
            skipped += 1
            continue
        sys_total, gold_total, correct = count_spans(sys_tok, gold_tok)
        score = cal_f1(sys_total, gold_total, correct)
        sys_total_s += sys_total
        gold_total_s += gold_total
        correct_s += correct
        macro_f += score.f1
        size += 1

    if skipped:
        print(f'warning: skipped {skipped} sentence pair(s) whose text differs', file=sys.stderr)
    if unpaired:
        print(f'warning: {unpaired} sentence(s) have no counterpart in the other file', file=sys.stderr)

    # Avoid ZeroDivisionError when nothing was scored. 0.0 mirrors cal_f1's
    # convention for an empty denominator.
    macro_avg = macro_f / size if size else 0.0
    print(macro_avg, cal_f1(sys_total_s, gold_total_s, correct_s).f1, sys_total_s, gold_total_s, correct_s)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment