MTMSN 논문 정리
MTMSN 논문 정리
결론적으로, MTMSN 모델은 Drop 데이터의 특성에 맞추어 모든 문제 유형을 커버할 수 있는 구조로 작성되었다는 느낌을 받았음. QANet model이 대체 뭘 말해준 것인지는 모르지만, 이를 기반으로 BERT encoder의 마지막 layer단의 output을 가지고 모든 answer type의 답을 유추하는 모델을 만들었다는 것은 꽤나 흥미로웠다.
mwp에서 해야하는 문제로써의 적절성은 띄지 못한 것 같다. 직접적으로는 equation을 예측하여야 하는 내용들이 나오지 않았고, 간접적으로 common sense를 부여할 수 있는 방법에 대해 insight를 얻을 수 있지 않을까 했는데, 그런 부분에서는 encoder의 마지막 layer output을 적절히 사용하는 것이 가장 큰 수확이었다고 생각들었다.
1. Introduction
이전 SQUAD 혹은 CNN/DM과 같은 데이터는 정답이 단순하여 꽤나 성능이 좋았지만, DROP이 나오게 되며 모델은 다양한 값에 대한 예측을 하여야 했다.
특히 여러 개의 span이 존재하거나, equation을 풀거나, negation을 하는 문제들이 존재하였다. MTMSN은 bert를 맨 위에 두고, 각각의 sub task를 처리하기 위해 모듈을 둔다. 문제 정답의 갯수를 예측하고, 겹치지 않는 정답을 유추한다. 또한 수학 표현 reranking하는 것을 beam search를 사용하여 예측을 확신하게끔 만들어 준다.
2. Task Description
그.. 문제 또 보면 답이 여러 줄로도 나올 수도 있음. answer type에 따라 task를 4개까지 종합하는 모습을 보여줌.
3. Our Approach
Dua et al.(2019) 에서 소개된 passage-question pair 형태를 사용하여 개별 예측 전략으로 삼는다. 나머지는 문제 유형에 대한 전략을 다시 한번 상기시켜주며, weakly-supervised 한 상태에서 학습이 진행된다고 한다. ( 모든 부분을 이해하고 다시 이해하는 과정이 필요할 것으로 보임. )
3.1 Bert-based Encoder
문제와 지문을 wordPiece로 tokenize하고, cls로 나누어 준다. 가장 앞과 뒤에 sep이 들어감. [cls] [q-tokens] [sep] [p-tokens] [sep] 형태임. 해당 input 형태를 bert Encoder로 넣어주고 출력을 봄.
3.2 Multi-Type Answer Predictor
-
Answer type prediction
Augmented QANet에 기인하여 H representation 마지막 4개를 사용하여 answer predictor로써 사용함. 각각을 M0~3로 둠. (대체 QANet은 뭔 논문이길래 저런걸 주장함/??;;) 정답 유형을 예측하기 위해 [sep] 토큰을 기준으로 M2를 Q2와 P2로 분리함. 이후 각각으로부터 hQ2, hP2가 도출될 수 있도록 network를 구성해줌. Ptype은 M3의 최종 representation의 첫번째 vector라고 함. (그래서 cls 인듯) 그를 사용해 ffn을 구성하여 타입을 예측… -> pooler에서 따로 cls 토큰만을 사용하도록 slicing함.
-
Span
gating mechanism과 보편적 decoding 전략을 함께하여 지문과 문제에서 시작과 끝은 예측하고 정답을 추출함. (이후 설명 후 다시 확인) 문제 정보를 요약한 gQ0~2 를 각 Q0~2에 대한 FFN 연산과 자신의 곱으로 나타냄. (대충은) span의 시작과 끝은 예측하기 위해 layer 2,1,0 번째의 출력 M0~2를 적절히 사용하여 concat하고, 외적한 내용도 담아서 수식을 완성한다.
-
Arithmetic expression
passage에서 나타나는 수들의 representation u0~N을 M2와 M3의 concatenation으로 구할 수 있다고 함… 구해진 각 u에 할당되는 sign을 구하기 위해 hq2,hp2,hcls를 연결하여 사용함. (하나의 수에 하나의 연산? 코드 보고 조금 더 이해해봐야할 듯)
- Count
count 하는 문제를 mulit-class classification로써 생각함. 위에서 구한 U를 나타내는 hU를 구하고 이를 적절히 사용하여 classification할 수 있게 하는 듯.
- Negation
각 수에 대해 negation operation를 적용할지 말지에 대한 categorical variable 을 정의해 적용시켜준다. logistical probability로 이를 사용한다.
3.3 Multi-Span Extreaction
이것도 classification 문제로 바라보고 span의 개수를 예측하도록 한다. 중복되지 않는 span을 구하는 것이 중점인데, 이를 non-maximum suppression 알고리즘을 사용하여 해결한다. (*** 이후 계산식이 서술되지 않아 잘 모르겠음. 코드보고 확인 요망)