## Fetch data from Google Drive

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive so the CSVs under MyDrive can be copied
# into the working directory further down.
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
!pip install transformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/81/91/61d69d58a1af1bd81d9ca9d62c90a6de3ab80d77f27c5df65d9a2c1f5626/transformers-4.5.0-py3-none-any.whl (2.1MB)
[K     |████████████████████████████████| 2.2MB 9.4MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/08/cd/342e584ee544d044fb573ae697404ce22ede086c9e87ce5960772084cad0/sacremoses-0.0.44.tar.gz (862kB)
[K     |████████████████████████████████| 870kB 41.6MB/s 
[?25hCollecting tokenizers<0.11,>=0.10.1
[?25l  Downloading https://files.pythonhosted.org/packages/ae/04/5b870f26a858552025a62f1649c20d29d2672c02ff3c3fb4c688ca46467a/tokenizers-0.10.2-cp37-cp37m-manylinux2010_x86_64.whl (3.3MB)
[K     |████████████████████████████████| 3.3MB 42.0MB/s 
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
  Created wheel for sacremoses: filename=sacremoses-0.0.44-cp37-none-any.whl size=886084 sha256=8188b

## Data Prep

In [None]:
import pandas as pd
from torch.utils.data import Dataset, IterableDataset, DataLoader, get_worker_info
from transformers import DistilBertTokenizerFast
import math

In [None]:
ALL_LANGS = ["A# .NET","A# (Axiom)","A-0 System","A+","A++","ABAP","ABC","ABC ALGOL","ABLE","ABSET","ABSYS","ACC","Accent","Ace DASL","ACL2","ACT-III","Action!","ActionScript","Ada","Adenine","Agda","Agilent VEE","Agora","AIMMS","Alef","ALF","ALGOL 58","ALGOL 60","ALGOL 68","ALGOL W","Alice","Alma-0","AmbientTalk","Amiga E","AMOS","AMPL","APL","App Inventor for Android's visual block language","AppleScript","Arc","ARexx","Argus","AspectJ","Assembly language","ATS","Ateji PX","AutoHotkey","Autocoder","AutoIt","AutoLISP / Visual LISP","Averest","AWK","Axum","B","Babbage","Bash","BASIC","bc","BCPL","BeanShell","Batch (Windows/Dos)","Bertrand","BETA","Bigwig","Bistro","BitC","BLISS","Blue","Bon","Boo","Boomerang","Bourne shell","bash","ksh","BREW","BPEL","C","C--","C++","C#","C/AL","Caché ObjectScript","C Shell","Caml","Candle","Cayenne","CDuce","Cecil","Cel","Cesil","Ceylon","CFEngine","CFML","Cg","Ch","Chapel","CHAIN","Charity","Charm","Chef","CHILL","CHIP-8","chomski","ChucK","CICS","Cilk","CL","Claire","Clarion","Clean","Clipper","CLIST","Clojure","CLU","CMS-2","COBOL","Cobra","CODE","CoffeeScript","Cola","ColdC","ColdFusion","COMAL","Combined Programming Language","COMIT","Common Intermediate Language","Common Lisp","COMPASS","Component Pascal","Constraint Handling Rules","Converge","Cool","Coq","Coral 66","Corn","CorVision","COWSEL","CPL","csh","CSP","Csound","CUDA","Curl","Curry","Cyclone","Cython","D","DASL","DASL","Dart","DataFlex","Datalog","DATATRIEVE","dBase","dc","DCL","Deesel","Delphi","DinkC","DIBOL","Dog","Draco","DRAKON","Dylan","DYNAMO","E","E#","Ease","Easy PL/I","Easy Programming Language","EASYTRIEVE PLUS","ECMAScript","Edinburgh IMP","EGL","Eiffel","ELAN","Elixir","Elm","Emacs Lisp","Emerald","Epigram","EPL","Erlang","es","Escapade","Escher","ESPOL","Esterel","Etoys","Euclid","Euler","Euphoria","EusLisp Robot Programming Language","CMS EXEC","EXEC 2","Executable 
UML","F","F#","Factor","Falcon","Fancy","Fantom","FAUST","Felix","Ferite","FFP","Fjölnir","FL","Flavors","Flex","FLOW-MATIC","FOCAL","FOCUS","FOIL","FORMAC","@Formula","Forth","Fortran","Fortress","FoxBase","FoxPro","FP","FPr","Franz Lisp","Frege","F-Script","FSProg","G","Google Apps Script","Game Maker Language","GameMonkey Script","GAMS","GAP","G-code","Genie","GDL","Gibiane","GJ","GEORGE","GLSL","GNU E","GM","Go","Go!","GOAL","Gödel","Godiva","GOM (Good Old Mad)","Goo","Gosu","GOTRAN","GPSS","GraphTalk","GRASS","Groovy","Hack (programming language)","HAL/S","Hamilton C shell","Harbour","Hartmann pipelines","Haskell","Haxe","High Level Assembly","HLSL","Hop","Hope","Hugo","Hume","HyperTalk","IBM Basic assembly language","IBM HAScript","IBM Informix-4GL","IBM RPG","ICI","Icon","Id","IDL","Idris","IMP","Inform","Io","Ioke","IPL","IPTSCRAE","ISLISP","ISPF","ISWIM","J","J#","J++","JADE","Jako","JAL","Janus","JASS","Java","JavaScript","JCL","JEAN","Join Java","JOSS","Joule","JOVIAL","Joy","JScript","JScript .NET","JavaFX Script","Julia","Jython","K","Kaleidoscope","Karel","Karel++","KEE","Kixtart","KIF","Kojo","Kotlin","KRC","KRL","KUKA","KRYPTON","ksh","L","L# .NET","LabVIEW","Ladder","Lagoona","LANSA","Lasso","LaTeX","Lava","LC-3","Leda","Legoscript","LIL","LilyPond","Limbo","Limnor","LINC","Lingo","Linoleum","LIS","LISA","Lisaac","Lisp","Lite-C","Lithe","Little b","Logo","Logtalk","LPC","LSE","LSL","LiveCode","LiveScript","Lua","Lucid","Lustre","LYaPAS","Lynx","M2001","M4","Machine code","MAD","MAD/I","Magik","Magma","make","Maple","MAPPER","MARK-IV","Mary","MASM Microsoft Assembly x86","Mathematica","MATLAB","Maxima","Macsyma","Max","MaxScript","Maya (MEL)","MDL","Mercury","Mesa","Metacard","Metafont","MetaL","Microcode","MicroScript","MIIS","MillScript","MIMIC","Mirah","Miranda","MIVA Script","ML","Moby","Model 
204","Modelica","Modula","Modula-2","Modula-3","Mohol","MOO","Mortran","Mouse","MPD","CIL","MSL","MUMPS","NASM","NATURAL","Napier88","Neko","Nemerle","nesC","NESL","Net.Data","NetLogo","NetRexx","NewLISP","NEWP","Newspeak","NewtonScript","NGL","Nial","Nice","Nickle","Nim","NPL","Not eXactly C","Not Quite C","NSIS","Nu","NWScript","NXT-G","o:XML","Oak","Oberon","Obix","OBJ2","Object Lisp","ObjectLOGO","Object REXX","Object Pascal","Objective-C","Objective-J","Obliq","Obol","OCaml","occam","occam-π","Octave","OmniMark","Onyx","Opa","Opal","OpenCL","OpenEdge ABL","OPL","OPS5","OptimJ","Orc","ORCA/Modula-2","Oriel","Orwell","Oxygene","Oz","P#","ParaSail (programming language)","PARI/GP","Pascal","Pawn","PCASTL","PCF","PEARL","PeopleCode","Perl","PDL","PHP","Phrogram","Pico","Picolisp","Pict","Pike","PIKT","PILOT","Pipelines","Pizza","PL-11","PL/0","PL/B","PL/C","PL/I","PL/M","PL/P","PL/SQL","PL360","PLANC","Plankalkül","Planner","PLEX","PLEXIL","Plus","POP-11","PostScript","PortablE","Powerhouse","PowerBuilder","PowerShell","PPL","Processing","Processing.js","Prograph","PROIV","Prolog","PROMAL","Promela","PROSE modeling language","PROTEL","ProvideX","Pro*C","Pure","Python","Q (equational programming language)","Q (programming language from Kx Systems)","Qalb","QtScript","QuakeC","QPL","R","R++","Racket","RAPID","Rapira","Ratfiv","Ratfor","rc","REBOL","Red","Redcode","REFAL","Reia","Revolution","rex","REXX","Rlab","RobotC","ROOP","RPG","RPL","RSL","RTL/2","Ruby","RuneScript","Rust","S","S2","S3","S-Lang","S-PLUS","SA-C","SabreTalk","SAIL","SALSA","SAM76","SAS","SASL","Sather","Sawzall","SBL","Scala","Scheme","Scilab","Scratch","Script.NET","Sed","Seed7","Self","SenseTalk","SequenceL","SETL","Shift Script","SIMPOL","SIGNAL","SiMPLE","SIMSCRIPT","Simula","Simulink","SISAL","SLIP","SMALL","Smalltalk","Small Basic","SML","Snap!","SNOBOL","SPITBOL","Snowball","SOL","Span","SPARK","Speedcode","SPIN","SP/k","SPS","Squeak","Squirrel","SR","S/SL","Stackless 
Python","Starlogo","Strand","Stata","Stateflow","Subtext","SuperCollider","SuperTalk","Swift (Apple programming language)","Swift (parallel scripting language)","SYMPL","SyncCharts","SystemVerilog","T","TACL","TACPOL","TADS","TAL","Tcl","Tea","TECO","TELCOMP","TeX","TEX","TIE","Timber","TMG","Tom","TOM","Topspeed","TPU","Trac","TTM","T-SQL","TTCN","Turing","TUTOR","TXL","TypeScript","Turbo C++","Ubercode","UCSD Pascal","Umple","Unicon","Uniface","UNITY","Unix shell","UnrealScript","Vala","VBA","VBScript","Verilog","VHDL","Visual Basic","Visual Basic .NET","Visual DataFlex","Visual DialogScript","Visual Fortran","Visual FoxPro","Visual J++","Visual J#","Visual Objects","Visual Prolog","VSXu","Vvvv","WATFIV, WATFOR","WebDNA","WebQL","Windows PowerShell","Winbatch","Wolfram","Wyvern","X++","X#","X10","XBL","XC","XMOS architecture","xHarbour","XL","Xojo","XOTcl","XPL","XPL0","XQuery","XSB","XSLT","XPath","Xtend","Yorick","YQL","Z notation","Zeno","ZOPL","ZPL"]
# Lowercase every name so it can be compared with the lowercased tags; the set
# is what the labeling code uses for membership tests (and it also collapses
# case-only duplicates such as "Bash"/"bash" or "TeX"/"TEX").
# NOTE(review): multi-word names keep their spaces (e.g. 'common lisp'), while
# tags in this data are hyphenated (e.g. 'twitter-bootstrap'), so such names
# presumably never match a tag; "WATFIV, WATFOR" is also a single entry.
ALL_LANGS = [name.lower() for name in ALL_LANGS]
ALL_LANGS_SET = set(ALL_LANGS)
# Print a short summary rather than the whole set: dumping every name floods
# the cell output and is unreadable.
print(f"{len(ALL_LANGS_SET)} distinct language names (from {len(ALL_LANGS)} list entries), "
      f"e.g. {sorted(ALL_LANGS_SET)[:10]}")

tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

{'k', 'gamemonkey script', 'prose modeling language', 'lua', 'x++', 'txl', 'joule', 'xotcl', 'opal', 'lite-c', 'algol 58', 'uniface', 'spin', 'icon', 'promela', 'g', 'chapel', 'nickle', 'nesc', 'boomerang', 'xl', 'l', 'averest', 'b', 'dataflex', 'datatrieve', 'smalltalk', 'mary', 'foxpro', 'prograph', 'moby', 'stackless python', 'assembly language', 'c/al', 'amiga e', 'caml', 'beanshell', 'rpg', 'ada', 'postscript', 'f#', 'snap!', 'iptscrae', 'dbase', 'ferite', 'ceylon', 'visual foxpro', 'redcode', 'quakec', 'mortran', 'kif', 'xpath', 'vbscript', 'goo', 'ace dasl', 'cg', 'lynx', 'sbl', 'yql', 'a+', 'csound', 'dibol', 'abc algol', 'elan', 'q (equational programming language)', 'spitbol', 'lansa', 'rapid', 'csp', 'octave', 'george', 'cil', 'bon', 'cfml', 'pipelines', 'karel++', 'charity', 'argus', 'arc', 'alice', 'a++', 'clu', 'elixir', 'simscript', 'lilypond', 'starlogo', 'pizza', 'pl/0', 'self', 'clojure', 'reia', 'masm microsoft assembly x86', 'maxscript', 'latex', 'ppl', 'tmg', 'fp',

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466062.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=28.0, style=ProgressStyle(description_w…




In [None]:
# Copy the train/test CSVs from the mounted Drive folder into the working
# directory; the loading cell below reads them as ./howdoi_train.csv and
# ./howdoi_test.csv.
!cp ./drive/MyDrive/howdoi_train.csv ./
!cp ./drive/MyDrive/howdoi_test.csv ./

In [None]:
# DISABLED: earlier approach that streamed the train CSV in chunks into an h5
# file (via h5py) so rows could be read lazily instead of keeping everything in
# memory. The code sits inside a string literal, so it never executes; the next
# cell loads both CSVs directly with pandas instead.
# NOTE(review): `print(has_tags)` inside the chunk loop prints every label of
# every chunk - presumably what tripped the IOPub data-rate limit in the saved
# output below; drop it before re-enabling.
# NOTE(review): `wc -l` counts physical lines including the header, while data
# rows start at row 1 (skiprows=i), so the h5 datasets get one more slot than
# there are rows; the count is also off if any quoted field spans lines.
'''
import h5py
import numpy as np

import subprocess
train_path, test_path = "./howdoi_train.csv",  "./howdoi_test.csv"
h5_train_path, h5_test_path = "./data_tr.h5", "./data_ts.h5"

# this is just a random large number, this size of data (short strings)
#   doesn't take much RAM, not even sure we have to read it in chunks at all
chunksize = 1000 * 10000

# hacky way of reading the length of the file without opening it
lines_train = subprocess.check_output(['wc', '-l', train_path])
lines_train = int(lines_train.split()[0])

# h5 is a format you can read from without loading up the data in memory
#   so it's perfect for huge datasets

# NOTE: this will take a minute or so
with h5py.File(h5_train_path, 'w') as h5f:
    # use num_features if the csv file has no column header
    texts = h5f.create_dataset("text-train",
                               shape=(lines_train,),
                               compression=None,
                               dtype=h5py.string_dtype('utf-8'))
    labels = h5f.create_dataset("label-train",
                               shape=(lines_train,),
                               compression=None,
                               dtype="bool")

    # read num_lines in chunks of size chunksize
    for i in range(1, lines_train, chunksize):  

        df = pd.read_csv(
          train_path,  
          header=None, # we ignore the header by starting the loop from row 1
          nrows=chunksize,
          skiprows=i
        )
        
        titles = df.values[:, -2]

        # you don't have to do this at this step, you could also just store
        #   this as a string, like in the original csv
        has_tags = [
          len(set(str(t).lower().split('|')).intersection(ALL_LANGS_SET)) > 0
          for t in df.values[:, -1]
        ]
        print(has_tags)

        items_num = len(titles)

        # this fills in the current chunk of the h5 file
        texts[i-1:i-1+items_num] = titles
        labels[i-1:i-1+items_num] = has_tags

# Create test set

'''

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [None]:
import h5py
import numpy as np

import subprocess
train_path, test_path = "./howdoi_train.csv",  "./howdoi_test.csv"
# NOTE(review): the h5 paths are only used by the disabled h5py cell and are
# passed (and ignored) as QueryDataset's `filename`; no h5 file is written.
h5_train_path, h5_test_path = "./data_tr.h5", "./data_ts.h5"

# Leftovers from the disabled chunked h5py loader; the pandas pipeline below
# does not read them. Kept so the names still exist for any cell that does.
chunksize = 1000 * 10000
# Physical line count of the train CSV via `wc -l` (includes the header line).
lines_train = subprocess.check_output(['wc', '-l', train_path])
lines_train = int(lines_train.split()[0])

# keep_default_na=False: by default pandas turns titles such as "NA", "null",
# "nan" (and empty fields) into float NaN, which the tokenizer in collate_fn
# cannot encode. With this flag such values stay strings ('' when empty).
# Labels are unaffected: a missing tag string maps to False either way.
df_train = pd.read_csv(train_path, keep_default_na=False)
df_test = pd.read_csv(test_path, keep_default_na=False)

In [None]:
def has_known_language_tag(tags):
  """Return True if any '|'-separated tag in `tags` is a name in ALL_LANGS_SET.

  `tags` is the raw CSV value (e.g. 'javascript|html|canvas'); it is passed
  through str() and lowercased first, so non-string values never raise.
  """
  return not ALL_LANGS_SET.isdisjoint(str(tags).lower().split('|'))

# Replace the raw tag strings with the boolean label used for training.
# Guarded on dtype so that re-running this cell is a no-op: mapping the
# booleans again would give str(True).lower() == 'true', which is not a
# language, silently turning every label into False.
for frame in (df_train, df_test):
  if frame['tags'].dtype != bool:
    frame['tags'] = frame['tags'].map(has_known_language_tag)

In [None]:
# Preview the first five rows of the test split. After the mapping cell above,
# `tags` holds the boolean label; the saved output still shows raw tag strings,
# so it was produced before that cell ran.
df_test.iloc[:5]

Unnamed: 0.2,Unnamed: 0,Unnamed: 0.1,title,tags
0,11971400,11971400,Changing colors of shapes in HTML5 canvas,javascript|html|canvas|polygon
1,5433772,5433772,Where to look for DB file after update-database?,c#|.net|entity
2,8996304,8996304,Graddle missing transitive dependency,maven|gradle|transitive-dependency
3,7648213,7648213,laravel link does work but button does not,twitter-bootstrap|laravel
4,14123938,14123938,Elegant haskell case/error handling in sequent...,haskell


In [None]:
import torch

In [None]:
class QueryDataset(Dataset):
  """Map-style dataset of (title, binary label) pairs for the language classifier.

  Each item is a dict ``{'title': <title>, 'label': 0 or 1}``. Tokenization is
  deferred to ``collate_fn`` so that a whole batch is padded together.

  Args:
    filename: Ignored. Kept for call compatibility with the earlier h5-backed
      version of this dataset (the h5 files are no longer written).
    kind: ``'train'`` or ``'test'``; selects the global ``df_train`` or
      ``df_test`` frame when ``df`` is not given.
    df: Optional DataFrame with a ``'title'`` column and a boolean ``'tags'``
      column, used instead of the frame selected by ``kind``.

  Raises:
    ValueError: if ``df`` is None and ``kind`` is neither 'train' nor 'test'.
  """

  def __init__(self, filename, kind, df=None):
    if df is None:
      # Previously `kind` was ignored and df_train was always used, so a
      # 'test' dataset silently served training data.
      if kind == 'train':
        df = df_train
      elif kind == 'test':
        df = df_test
      else:
        raise ValueError(f"kind must be 'train' or 'test', got {kind!r}")
    # Reset the index so positional access in __getitem__ is valid even for
    # filtered or shuffled frames whose index is not 0..n-1.
    self.titles = df['title'].reset_index(drop=True)
    self.labels = df['tags'].reset_index(drop=True)

  def __len__(self):
    return self.titles.shape[0]

  def __getitem__(self, i):
    # Data is already in memory (pandas); this is a plain positional lookup.
    title = self.titles.iloc[i]
    label = 1 if bool(self.labels.iloc[i]) else 0
    return {'title': title, 'label': label}

In [None]:
def collate_fn(data):
  """Turn a list of QueryDataset items into one padded batch of tensors.

  Titles are tokenized together (truncated, padded to the longest in the
  batch) with the global `tokenizer`; every tokenizer output becomes a tensor,
  and the 0/1 labels are added under the key 'labels'.
  """
  titles = []
  labels = []
  for item in data:
    titles.append(item['title'])
    labels.append(item['label'])

  encoded = tokenizer(titles, truncation=True, padding=True)

  batch = {}
  for key, values in encoded.items():
    batch[key] = torch.tensor(values)
  batch['labels'] = torch.tensor(labels)
  return batch

In [None]:
# Wrap the training frame and batch it. collate_fn is required, not redundant:
# it tokenizes the raw titles and produces the padded input_ids /
# attention_mask / labels tensors the model consumes. Passing it explicitly
# also sidesteps the RuntimeError described in
# https://github.com/pytorch/pytorch/issues/42654#issuecomment-706926806
trainset = QueryDataset(h5_train_path, 'train')
trainloader = DataLoader(trainset, batch_size=256, num_workers=2, shuffle=True,
                         collate_fn=collate_fn)

# Sanity-check a single batch: report tensor shapes and the share of positive
# labels instead of printing the full tensors.
sample_batch = next(iter(trainloader))
print({name: tuple(tensor.shape) for name, tensor in sample_batch.items()})
print('fraction of positive labels:', sample_batch['labels'].float().mean().item())

{'input_ids': tensor([[  101, 24357,  3746,  ...,     0,     0,     0],
        [  101, 10463,  5164,  ...,     0,     0,     0],
        [  101,  2129,  2000,  ...,     0,     0,     0],
        ...,
        [  101,  9585,  8011,  ...,     0,     0,     0],
        [  101,  2054,  2024,  ...,     0,     0,     0],
        [  101,  2129,  2064,  ...,     0,     0,     0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'labels': tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1,
        0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0,
        1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0,
        1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0,
        1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 

## Model training code

In [None]:
from transformers import DistilBertForSequenceClassification, AdamW

In [None]:
# Training hyperparameters (previously inline magic numbers).
NUM_EPOCHS = 3
LEARNING_RATE = 5e-5
SEED = 42

# Seed torch's global RNG before building the model so the newly initialised
# classifier head, dropout and the DataLoader's shuffle order are reproducible
# across runs (up to nondeterministic GPU kernels).
torch.manual_seed(SEED)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
model.to(device)
model.train()  # enable dropout for fine-tuning

optim = AdamW(model.parameters(), lr=LEARNING_RATE)

for epoch in range(NUM_EPOCHS):
  # Accumulate the loss on-device so there is no host sync on every step.
  epoch_loss = torch.zeros((), device=device)
  num_batches = 0
  for batch in trainloader:
    optim.zero_grad()
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    labels = batch['labels'].to(device)
    outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
    loss = outputs[0]  # first output is the loss when `labels` is passed
    loss.backward()
    optim.step()
    epoch_loss += loss.detach()
    num_batches += 1
  mean_loss = epoch_loss.item() / max(num_batches, 1)
  print(f'epoch {epoch + 1}/{NUM_EPOCHS}: mean train loss {mean_loss:.4f}')

# Switch to inference mode; the trailing ';' keeps Jupyter from printing the
# full model repr.
model.eval();

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=442.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=267967963.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classi