Aurelien-Morgan-Bot commited on
Commit
3ec7dd1
·
verified ·
1 Parent(s): 4fbd4b8

source-code for model version v0.32_20260221_012419846_UTC- retrain-pipelines 0.1.2

Browse files
v0.32_20260221_012419846_UTC/requirements.txt ADDED
@@ -0,0 +1,735 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.4.0
2
+ accelerate==1.1.1
3
+ access==1.1.10.post3
4
+ affine==2.4.0
5
+ aiofiles==24.1.0
6
+ aiohappyeyeballs==2.6.1
7
+ aiohttp==3.13.3
8
+ aiosignal==1.4.0
9
+ aiosqlite==0.22.1
10
+ alabaster==1.0.0
11
+ albucore==0.0.24
12
+ albumentations==2.0.8
13
+ ale-py==0.11.2
14
+ alembic==1.18.4
15
+ altair==5.5.0
16
+ annotated-doc==0.0.4
17
+ annotated-types==0.7.0
18
+ antlr4-python3-runtime==4.9.3
19
+ anyio==4.12.1
20
+ anywidget==0.9.21
21
+ apsw==3.51.2.0
22
+ apswutils==0.1.2
23
+ argon2-cffi==25.1.0
24
+ argon2-cffi-bindings==25.1.0
25
+ array_record==0.8.3
26
+ arrow==1.4.0
27
+ arviz==0.22.0
28
+ astropy==7.2.0
29
+ astropy-iers-data==0.2026.2.16.0.48.25
30
+ asttokens==3.0.1
31
+ astunparse==1.6.3
32
+ atpublic==5.1
33
+ attrs==25.4.0
34
+ audioread==3.1.0
35
+ Authlib==1.6.8
36
+ autograd==1.8.0
37
+ babel==2.18.0
38
+ backcall==0.2.0
39
+ beartype==0.22.9
40
+ beautifulsoup4==4.13.5
41
+ betterproto==2.0.0b6
42
+ bigframes==2.35.0
43
+ bigquery-magics==0.10.3
44
+ bitsandbytes==0.44.1
45
+ bleach==6.3.0
46
+ blinker==1.9.0
47
+ blis==1.3.3
48
+ blobfile==3.2.0
49
+ blosc2==4.0.0
50
+ bokeh==3.7.3
51
+ boto3==1.42.53
52
+ botocore==1.42.53
53
+ Bottleneck==1.4.2
54
+ bqplot==0.12.45
55
+ branca==0.8.2
56
+ brotli==1.2.0
57
+ CacheControl==0.14.4
58
+ cachetools==7.0.1
59
+ catalogue==2.0.10
60
+ certifi==2026.1.4
61
+ cffi==2.0.0
62
+ chardet==5.2.0
63
+ charset-normalizer==3.4.4
64
+ clarabel==0.11.1
65
+ click==8.3.1
66
+ click-plugins==1.1.1.2
67
+ cligj==0.7.2
68
+ cloudpathlib==0.23.0
69
+ cloudpickle==3.1.2
70
+ cmake==3.31.10
71
+ cmdstanpy==1.3.0
72
+ colorcet==3.1.0
73
+ colorlover==0.3.0
74
+ colour==0.1.5
75
+ comm==0.2.3
76
+ community==1.0.0b1
77
+ confection==0.1.5
78
+ cons==0.4.7
79
+ contourpy==1.3.3
80
+ cramjam==2.11.0
81
+ cryptography==43.0.3
82
+ cuda-bindings==12.9.4
83
+ cuda-core==0.3.2
84
+ cuda-pathfinder==1.3.4
85
+ cuda-python==12.9.4
86
+ cuda-toolkit==12.8.1
87
+ cudf-cu12 @ https://pypi.nvidia.com/cudf-cu12/cudf_cu12-25.10.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
88
+ cudf-polars-cu12==25.10.0
89
+ cufflinks==0.17.3
90
+ cuml-cu12==25.10.0
91
+ cupy-cuda12x==13.6.0
92
+ curl_cffi==0.14.0
93
+ cvxopt==1.3.2
94
+ cvxpy==1.6.7
95
+ cycler==0.12.1
96
+ cyipopt==1.5.0
97
+ cymem==2.0.13
98
+ Cython==3.0.12
99
+ dask==2025.9.1
100
+ dask-cuda==25.10.0
101
+ dask-cudf-cu12==25.10.0
102
+ dataproc-spark-connect==1.0.2
103
+ datasets==4.0.0
104
+ db-dtypes==1.5.0
105
+ dbus-python==1.2.18
106
+ debugpy==1.8.15
107
+ decorator==5.2.1
108
+ defusedxml==0.7.1
109
+ deprecation==2.1.0
110
+ diffusers==0.36.0
111
+ dill==0.3.8
112
+ distributed==2025.9.1
113
+ distributed-ucxx-cu12==0.46.0
114
+ distro==1.9.0
115
+ dlib==19.24.6
116
+ dm-tree==0.1.9
117
+ docstring_parser==0.17.0
118
+ docutils==0.21.2
119
+ dopamine_rl==4.1.2
120
+ duckdb==1.3.2
121
+ earthengine-api==1.5.24
122
+ easydict==1.13
123
+ editdistance==0.8.1
124
+ eerepr==0.1.2
125
+ einops==0.8.2
126
+ en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl#sha256=1932429db727d4bff3deed6b34cfc05df17794f4a52eeb26cf8928f7c1a0fb85
127
+ entrypoints==0.4
128
+ esda==2.8.1
129
+ et_xmlfile==2.0.0
130
+ etils==1.13.0
131
+ etuples==0.3.10
132
+ executing==2.2.1
133
+ Farama-Notifications==0.0.4
134
+ fastai==2.8.7
135
+ fastapi==0.129.0
136
+ fastcore==1.12.14
137
+ fastdownload==0.0.7
138
+ fastjsonschema==2.21.2
139
+ fastlite==0.2.4
140
+ fastprogress==1.1.5
141
+ fastrlock==0.8.3
142
+ fasttransform==0.0.2
143
+ ffmpy==1.0.0
144
+ filelock==3.24.2
145
+ fiona==1.10.1
146
+ firebase-admin==6.9.0
147
+ Flask==3.1.2
148
+ flatbuffers==25.12.19
149
+ flax==0.11.2
150
+ folium==0.20.0
151
+ fonttools==4.61.1
152
+ fqdn==1.5.1
153
+ frozendict==2.4.7
154
+ frozenlist==1.8.0
155
+ fsspec==2025.3.0
156
+ future==1.0.0
157
+ gast==0.7.0
158
+ gcsfs==2025.3.0
159
+ GDAL==3.8.4
160
+ gdown==5.2.1
161
+ geemap==0.35.3
162
+ geocoder==1.38.1
163
+ geographiclib==2.1
164
+ geopandas==1.1.2
165
+ geopy==2.4.1
166
+ giddy==2.3.8
167
+ gin-config==0.5.0
168
+ gitdb==4.0.12
169
+ GitPython==3.1.46
170
+ glob2==0.7
171
+ google==3.0.0
172
+ google-adk==1.25.0
173
+ google-ai-generativelanguage==0.6.15
174
+ google-api-core==2.29.0
175
+ google-api-python-client==2.190.0
176
+ google-auth==2.47.0
177
+ google-auth-httplib2==0.3.0
178
+ google-auth-oauthlib==1.2.4
179
+ google-cloud-aiplatform==1.137.0
180
+ google-cloud-appengine-logging==1.8.0
181
+ google-cloud-audit-log==0.4.0
182
+ google-cloud-bigquery==3.40.1
183
+ google-cloud-bigquery-connection==1.20.0
184
+ google-cloud-bigquery-storage==2.36.1
185
+ google-cloud-bigtable==2.35.0
186
+ google-cloud-core==2.5.0
187
+ google-cloud-dataproc==5.24.0
188
+ google-cloud-datastore==2.23.0
189
+ google-cloud-discoveryengine==0.13.12
190
+ google-cloud-firestore==2.23.0
191
+ google-cloud-functions==1.22.0
192
+ google-cloud-iam==2.21.0
193
+ google-cloud-language==2.19.0
194
+ google-cloud-logging==3.13.0
195
+ google-cloud-monitoring==2.29.1
196
+ google-cloud-pubsub==2.35.0
197
+ google-cloud-resource-manager==1.16.0
198
+ google-cloud-secret-manager==2.26.0
199
+ google-cloud-spanner==3.63.0
200
+ google-cloud-speech==2.36.1
201
+ google-cloud-storage==3.9.0
202
+ google-cloud-trace==1.18.0
203
+ google-cloud-translate==3.24.0
204
+ google-colab @ file:///colabtools/dist/google_colab-1.0.0.tar.gz
205
+ google-crc32c==1.8.0
206
+ google-genai==1.63.0
207
+ google-generativeai==0.8.6
208
+ google-pasta==0.2.0
209
+ google-resumable-media==2.8.0
210
+ googleapis-common-protos==1.72.0
211
+ googledrivedownloader==1.1.0
212
+ gradio==5.50.0
213
+ gradio_client==1.14.0
214
+ grain==0.2.15
215
+ graphviz==0.21
216
+ greenlet==3.3.1
217
+ groovy==0.1.2
218
+ grpc-google-iam-v1==0.14.3
219
+ grpc-interceptor==0.15.4
220
+ grpcio==1.67.1
221
+ grpcio-health-checking==1.67.1
222
+ grpcio-status==1.71.2
223
+ grpclib==0.4.9
224
+ gspread==6.2.1
225
+ gspread-dataframe==4.0.0
226
+ gym==0.25.2
227
+ gym-notices==0.1.0
228
+ gymnasium==1.2.3
229
+ h11==0.16.0
230
+ h2==4.3.0
231
+ h5netcdf==1.8.1
232
+ h5py==3.15.1
233
+ hdbscan==0.8.41
234
+ hf-xet==1.2.0
235
+ hf_transfer==0.1.9
236
+ highspy==1.13.1
237
+ holidays==0.91
238
+ holoviews==1.22.1
239
+ hpack==4.1.0
240
+ html5lib==1.1
241
+ httpcore==1.0.9
242
+ httpimport==1.4.1
243
+ httplib2==0.31.2
244
+ httptools==0.7.1
245
+ httpx==0.28.1
246
+ httpx-sse==0.4.3
247
+ huggingface-hub==0.27.1
248
+ humanize==4.15.0
249
+ hyperframe==6.1.0
250
+ hyperopt==0.2.7
251
+ ibis-framework==9.5.0
252
+ idna==3.11
253
+ ImageIO==2.37.2
254
+ imageio-ffmpeg==0.6.0
255
+ imagesize==1.4.1
256
+ imbalanced-learn==0.14.1
257
+ immutabledict==4.3.1
258
+ importlib_metadata==8.7.1
259
+ importlib_resources==6.5.2
260
+ imutils==0.5.4
261
+ inequality==1.1.2
262
+ inflect==7.5.0
263
+ iniconfig==2.3.0
264
+ intel-cmplr-lib-ur==2025.3.2
265
+ intel-openmp==2025.3.2
266
+ ipyevents==2.0.4
267
+ ipyfilechooser==0.6.0
268
+ ipykernel==7.2.0
269
+ ipyleaflet==0.20.0
270
+ ipyparallel==8.8.0
271
+ ipython==8.21.0
272
+ ipython-genutils==0.2.0
273
+ ipython-sql==0.5.0
274
+ ipytree==0.2.2
275
+ ipywidgets==7.7.1
276
+ isoduration==20.11.0
277
+ itsdangerous==2.2.0
278
+ jaraco.classes==3.4.0
279
+ jaraco.context==6.1.0
280
+ jaraco.functools==4.4.0
281
+ jax==0.7.2
282
+ jax-cuda12-pjrt==0.7.2
283
+ jax-cuda12-plugin==0.7.2
284
+ jaxlib==0.7.2
285
+ jedi==0.19.2
286
+ jeepney==0.9.0
287
+ jieba==0.42.1
288
+ Jinja2==3.1.6
289
+ jiter==0.13.0
290
+ jmespath==1.1.0
291
+ joblib==1.5.3
292
+ jsonpatch==1.33
293
+ jsonpickle==4.1.1
294
+ jsonpointer==3.0.0
295
+ jsonschema==4.26.0
296
+ jsonschema-specifications==2025.9.1
297
+ jupyter-console==6.6.3
298
+ jupyter-events==0.12.0
299
+ jupyter-leaflet==0.20.0
300
+ jupyter_client==8.8.0
301
+ jupyter_core==5.9.1
302
+ jupyter_kernel_gateway @ git+https://github.com/googlecolab/kernel_gateway@b134e9945df25c2dcb98ade9129399be10788671
303
+ jupyter_server==2.14.0
304
+ jupyter_server_terminals==0.5.4
305
+ jupyterlab_pygments==0.3.0
306
+ jupyterlab_widgets==3.0.16
307
+ jupytext==1.19.1
308
+ kaggle==1.7.4.5
309
+ kagglehub==0.3.13
310
+ keras==3.10.0
311
+ keras-hub==0.21.1
312
+ keras-nlp==0.21.1
313
+ keyring==25.7.0
314
+ keyrings.google-artifactregistry-auth==1.1.2
315
+ kiwisolver==1.4.9
316
+ langchain==1.2.10
317
+ langchain-core==1.2.13
318
+ langgraph==1.0.8
319
+ langgraph-checkpoint==4.0.0
320
+ langgraph-prebuilt==1.0.7
321
+ langgraph-sdk==0.3.6
322
+ langsmith==0.7.3
323
+ lark==1.3.1
324
+ lazy_loader==0.4
325
+ libclang==18.1.1
326
+ libcudf-cu12 @ https://pypi.nvidia.com/libcudf-cu12/libcudf_cu12-25.10.0-py3-none-manylinux_2_28_x86_64.whl
327
+ libcugraph-cu12==25.10.1
328
+ libcuml-cu12==25.10.0
329
+ libkvikio-cu12==25.10.0
330
+ libpysal==4.14.1
331
+ libraft-cu12==25.10.0
332
+ librmm-cu12==25.10.0
333
+ librosa==0.11.0
334
+ libucx-cu12==1.19.0
335
+ libucxx-cu12==0.46.0
336
+ lightgbm==4.6.0
337
+ linkify-it-py==2.0.3
338
+ llvmlite==0.43.0
339
+ locket==1.0.0
340
+ logical-unification==0.4.7
341
+ lxml==6.0.2
342
+ Mako==1.3.10
343
+ mapclassify==2.10.0
344
+ Markdown==3.10.2
345
+ markdown-it-py==4.0.0
346
+ MarkupSafe==3.0.3
347
+ matplotlib==3.10.0
348
+ matplotlib-inline==0.2.1
349
+ matplotlib-venn==1.1.2
350
+ mcp==1.26.0
351
+ mdit-py-plugins==0.5.0
352
+ mdurl==0.1.2
353
+ metaflow==2.19.19
354
+ mgwr==2.2.1
355
+ miniKanren==1.0.5
356
+ missingno==0.5.2
357
+ mistune==3.2.0
358
+ mizani==0.13.5
359
+ mkl==2025.3.1
360
+ ml_dtypes==0.5.4
361
+ mlxtend==0.23.4
362
+ mmh3==5.2.0
363
+ momepy==0.11.0
364
+ more-itertools==10.8.0
365
+ moviepy==1.0.3
366
+ mpmath==1.3.0
367
+ msgpack==1.1.2
368
+ multidict==6.7.1
369
+ multipledispatch==1.0.0
370
+ multiprocess==0.70.16
371
+ multitasking==0.0.12
372
+ murmurhash==1.0.15
373
+ music21==9.9.1
374
+ namex==0.1.0
375
+ narwhals==2.16.0
376
+ natsort==8.4.0
377
+ nbclassic==1.3.3
378
+ nbclient==0.10.4
379
+ nbconvert==7.17.0
380
+ nbformat==5.10.4
381
+ ndindex==1.10.1
382
+ nest-asyncio==1.6.0
383
+ networkx==3.6.1
384
+ nibabel==5.3.3
385
+ nltk==3.9.1
386
+ notebook==6.5.7
387
+ notebook_shim==0.2.4
388
+ numba==0.60.0
389
+ numba-cuda==0.19.2
390
+ numexpr==2.14.1
391
+ numpy==2.4.2
392
+ nvidia-cublas-cu12==12.1.3.1
393
+ nvidia-cuda-cccl-cu12==12.9.27
394
+ nvidia-cuda-cupti-cu12==12.1.105
395
+ nvidia-cuda-nvcc-cu12==12.5.82
396
+ nvidia-cuda-nvrtc-cu12==12.1.105
397
+ nvidia-cuda-runtime-cu12==12.1.105
398
+ nvidia-cudnn-cu12==9.1.0.70
399
+ nvidia-cufft-cu12==11.0.2.54
400
+ nvidia-curand-cu12==10.3.2.106
401
+ nvidia-cusolver-cu12==11.4.5.107
402
+ nvidia-cusparse-cu12==12.1.0.106
403
+ nvidia-ml-py==13.590.48
404
+ nvidia-nccl-cu12==2.21.5
405
+ nvidia-nvjitlink-cu12==12.9.86
406
+ nvidia-nvtx-cu12==12.1.105
407
+ nvtx==0.2.14
408
+ nx-cugraph-cu12 @ https://pypi.nvidia.com/nx-cugraph-cu12/nx_cugraph_cu12-25.10.0-py3-none-any.whl
409
+ oauth2client==4.1.3
410
+ oauthlib==3.3.1
411
+ omegaconf==2.3.0
412
+ onemkl-license==2025.3.1
413
+ openai==2.21.0
414
+ opencv-contrib-python==4.13.0.92
415
+ opencv-python==4.13.0.92
416
+ opencv-python-headless==4.13.0.92
417
+ openpyxl==3.1.5
418
+ opentelemetry-api==1.38.0
419
+ opentelemetry-exporter-gcp-logging==1.11.0a0
420
+ opentelemetry-exporter-gcp-monitoring==1.11.0a0
421
+ opentelemetry-exporter-gcp-trace==1.11.0
422
+ opentelemetry-exporter-otlp-proto-common==1.38.0
423
+ opentelemetry-exporter-otlp-proto-http==1.38.0
424
+ opentelemetry-proto==1.38.0
425
+ opentelemetry-resourcedetector-gcp==1.11.0a0
426
+ opentelemetry-sdk==1.38.0
427
+ opentelemetry-semantic-conventions==0.59b0
428
+ opt_einsum==3.4.0
429
+ optax==0.2.7
430
+ optree==0.18.0
431
+ orbax-checkpoint==0.11.32
432
+ orjson==3.11.7
433
+ ormsgpack==1.12.2
434
+ osqp==1.1.1
435
+ overrides==7.7.0
436
+ packaging==26.0
437
+ pandas==2.2.2
438
+ pandas-datareader==0.10.0
439
+ pandas-gbq==0.30.0
440
+ pandas-stubs==2.2.2.240909
441
+ pandocfilters==1.5.1
442
+ panel==1.8.7
443
+ param==2.3.2
444
+ parso==0.8.6
445
+ parsy==2.2
446
+ partd==1.4.2
447
+ patsy==1.0.2
448
+ peewee==3.19.0
449
+ peft==0.14.0
450
+ pexpect==4.9.0
451
+ pickleshare==0.7.5
452
+ pillow==11.3.0
453
+ pip3-autoremove==2.0.1
454
+ platformdirs==4.9.2
455
+ plotly==5.24.1
456
+ plotnine==0.14.5
457
+ pluggy==1.6.0
458
+ plum-dispatch==2.6.1
459
+ ply==3.11
460
+ pointpats==2.5.2
461
+ polars==1.31.0
462
+ pooch==1.9.0
463
+ portpicker==1.5.2
464
+ preshed==3.0.12
465
+ prettytable==3.17.0
466
+ proglog==0.1.12
467
+ progressbar2==4.5.0
468
+ prometheus_client==0.24.1
469
+ promise==2.3
470
+ prompt_toolkit==3.0.52
471
+ propcache==0.4.1
472
+ prophet==1.3.0
473
+ proto-plus==1.27.1
474
+ protobuf==5.29.6
475
+ psutil==5.9.5
476
+ psycopg2==2.9.11
477
+ psygnal==0.15.1
478
+ ptyprocess==0.7.0
479
+ PuLP==3.3.0
480
+ pure_eval==0.2.3
481
+ py-cpuinfo==9.0.0
482
+ py4j==0.10.9.9
483
+ pyarrow==18.1.0
484
+ pyasn1==0.6.2
485
+ pyasn1_modules==0.4.2
486
+ pycairo==1.29.0
487
+ pycocotools==2.0.11
488
+ pycparser==3.0
489
+ pycryptodomex==3.23.0
490
+ pydantic==2.12.3
491
+ pydantic-settings==2.13.0
492
+ pydantic_core==2.41.4
493
+ pydata-google-auth==1.9.1
494
+ pydot==4.0.1
495
+ pydotplus==2.0.2
496
+ PyDrive2==1.21.3
497
+ pydub==0.25.1
498
+ pyerfa==2.0.1.5
499
+ pygame==2.6.1
500
+ pygit2==1.19.1
501
+ Pygments==2.19.2
502
+ PyGObject==3.48.2
503
+ PyJWT==2.11.0
504
+ pylibcudf-cu12 @ https://pypi.nvidia.com/pylibcudf-cu12/pylibcudf_cu12-25.10.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
505
+ pylibcugraph-cu12==25.10.1
506
+ pylibraft-cu12==25.10.0
507
+ pymc==5.27.1
508
+ pynndescent==0.6.0
509
+ pyogrio==0.12.1
510
+ pyomo==6.9.5
511
+ PyOpenGL==3.1.10
512
+ pyOpenSSL==24.2.1
513
+ pyparsing==3.3.2
514
+ pyperclip==1.11.0
515
+ pyproj==3.7.2
516
+ pysal==25.7
517
+ pyshp==3.0.3
518
+ PySocks==1.7.1
519
+ pyspark==4.0.2
520
+ pytensor==2.37.0
521
+ pytest==8.4.2
522
+ python-apt==0.0.0
523
+ python-box==7.3.2
524
+ python-dateutil==2.9.0.post0
525
+ python-dotenv==1.2.1
526
+ python-fasthtml==0.12.42
527
+ python-json-logger==4.0.0
528
+ python-louvain==0.16
529
+ python-multipart==0.0.22
530
+ python-slugify==8.0.4
531
+ python-snappy==0.7.3
532
+ python-utils==3.9.1
533
+ pytz==2025.2
534
+ pyviz_comms==3.0.6
535
+ PyWavelets==1.9.0
536
+ PyYAML==6.0.3
537
+ pyzmq==26.2.1
538
+ quantecon==0.10.1
539
+ raft-dask-cu12==25.10.0
540
+ rapids-dask-dependency==25.10.0
541
+ rapids-logger==0.1.19
542
+ rasterio==1.5.0
543
+ rasterstats==0.20.0
544
+ ratelim==0.1.6
545
+ referencing==0.37.0
546
+ regex==2025.11.3
547
+ requests==2.32.4
548
+ requests-oauthlib==2.0.0
549
+ requests-toolbelt==1.0.0
550
+ requirements-parser==0.9.0
551
+ # Editable install with no version control (retrain-pipelines==0.0.0)
552
+ -e /content/pkg_src
553
+ rfc3339-validator==0.1.4
554
+ rfc3986-validator==0.1.1
555
+ rfc3987-syntax==1.1.0
556
+ rich==14.3.3
557
+ rmm-cu12==25.10.0
558
+ roman-numerals==4.1.0
559
+ roman-numerals-py==4.1.0
560
+ rpds-py==0.30.0
561
+ rpy2==3.5.17
562
+ rsa==4.9.1
563
+ rtree==1.4.1
564
+ ruff==0.15.1
565
+ s3transfer==0.16.0
566
+ safehttpx==0.1.7
567
+ safetensors==0.7.0
568
+ scikit-image==0.25.2
569
+ scikit-learn==1.6.1
570
+ scipy==1.16.3
571
+ scooby==0.11.0
572
+ scs==3.2.11
573
+ seaborn==0.13.2
574
+ SecretStorage==3.5.0
575
+ segregation==2.5.3
576
+ semantic-version==2.10.0
577
+ Send2Trash==2.1.0
578
+ sentence-transformers==5.2.3
579
+ sentencepiece==0.2.1
580
+ sentry-sdk==2.53.0
581
+ setuptools==80.10.2
582
+ shap==0.50.0
583
+ shapely==2.1.2
584
+ shellingham==1.5.4
585
+ simple-parsing==0.1.8
586
+ simplejson==3.20.2
587
+ simsimd==6.5.13
588
+ six==1.17.0
589
+ sklearn-compat==0.1.5
590
+ sklearn-pandas==2.2.0
591
+ slicer==0.0.8
592
+ smart_open==7.5.0
593
+ smmap==5.0.2
594
+ sniffio==1.3.1
595
+ snowballstemmer==3.0.1
596
+ sortedcontainers==2.4.0
597
+ soundfile==0.13.1
598
+ soupsieve==2.8.3
599
+ soxr==1.0.0
600
+ spacy==3.8.11
601
+ spacy-legacy==3.0.12
602
+ spacy-loggers==1.0.5
603
+ spaghetti==1.7.6
604
+ spanner-graph-notebook==1.1.8
605
+ spglm==1.1.0
606
+ Sphinx==8.2.3
607
+ sphinxcontrib-applehelp==2.0.0
608
+ sphinxcontrib-devhelp==2.0.0
609
+ sphinxcontrib-htmlhelp==2.1.0
610
+ sphinxcontrib-jsmath==1.0.1
611
+ sphinxcontrib-qthelp==2.0.0
612
+ sphinxcontrib-serializinghtml==2.0.0
613
+ spint==1.0.7
614
+ splot==1.1.7
615
+ spopt==0.7.0
616
+ spreg==1.8.5
617
+ SQLAlchemy==2.0.46
618
+ sqlalchemy-spanner==1.17.2
619
+ sqlglot==25.20.2
620
+ sqlite-web==0.7.1
621
+ sqlparse==0.5.5
622
+ srsly==2.5.2
623
+ sse-starlette==3.2.0
624
+ stack-data==0.6.3
625
+ stanio==0.5.1
626
+ starlette==0.52.1
627
+ statsmodels==0.14.6
628
+ stringzilla==4.6.0
629
+ stumpy==1.13.0
630
+ sympy==1.13.1
631
+ tables==3.10.2
632
+ tabulate==0.9.0
633
+ tbb==2022.3.1
634
+ tblib==3.2.2
635
+ tcmlib==1.4.1
636
+ tenacity==9.1.4
637
+ tensorboard==2.19.0
638
+ tensorboard-data-server==0.7.2
639
+ tensorflow==2.19.0
640
+ tensorflow-datasets==4.9.9
641
+ tensorflow-hub==0.16.1
642
+ tensorflow-metadata==1.17.3
643
+ tensorflow-probability==0.25.0
644
+ tensorflow-text==2.19.0
645
+ tensorflow_decision_forests==1.12.0
646
+ tensorstore==0.1.81
647
+ termcolor==3.3.0
648
+ terminado==0.18.1
649
+ text-unidecode==1.3
650
+ textblob==0.19.0
651
+ tf-slim==1.1.0
652
+ tf_keras==2.19.0
653
+ thinc==8.3.10
654
+ threadpoolctl==3.6.0
655
+ tifffile==2026.2.16
656
+ tiktoken==0.12.0
657
+ timm==1.0.24
658
+ tinycss2==1.4.0
659
+ tobler==0.13.0
660
+ tokenizers==0.20.3
661
+ toml==0.10.2
662
+ tomlkit==0.13.3
663
+ toolz==0.12.1
664
+ torch==2.5.1+cu121
665
+ torchaudio==2.5.1+cu121
666
+ torchcodec==0.10.0+cu128
667
+ torchdata==0.11.0
668
+ torchsummary==1.5.1
669
+ torchtune==0.6.1
670
+ torchvision==0.20.1+cu121
671
+ tornado==6.5.1
672
+ tqdm==4.67.3
673
+ traitlets==5.14.3
674
+ traittypes==0.2.3
675
+ transformers==4.46.2
676
+ treelite==4.4.1
677
+ treescope==0.1.10
678
+ triton==3.1.0
679
+ trl==0.12.0
680
+ tsfresh==0.21.1
681
+ tweepy==4.16.0
682
+ typeguard==4.5.0
683
+ typer==0.24.0
684
+ typer-slim==0.24.0
685
+ types-pytz==2025.2.0.20251108
686
+ types-setuptools==80.10.0.20260124
687
+ typing-inspection==0.4.2
688
+ typing_extensions==4.15.0
689
+ tyro==1.0.6
690
+ tzdata==2025.3
691
+ tzlocal==5.3.1
692
+ uc-micro-py==1.0.3
693
+ ucxx-cu12==0.46.0
694
+ umap-learn==0.5.11
695
+ umf==1.0.3
696
+ unsloth @ git+https://github.com/unslothai/unsloth.git@0c8c5ed81e423658ab9ae81eac5aab8d18f5d7af
697
+ unsloth_zoo==2024.11.5
698
+ uri-template==1.3.0
699
+ uritemplate==4.2.0
700
+ urllib3==2.5.0
701
+ uuid_utils==0.14.0
702
+ uvicorn==0.41.0
703
+ uvloop==0.22.1
704
+ vega-datasets==0.9.0
705
+ wandb==0.25.0
706
+ wasabi==1.1.3
707
+ watchdog==6.0.0
708
+ watchfiles==1.1.1
709
+ wcwidth==0.6.0
710
+ weasel==0.4.3
711
+ webcolors==25.10.0
712
+ webencodings==0.5.1
713
+ websocket-client==1.9.0
714
+ websockets==15.0.1
715
+ Werkzeug==3.1.6
716
+ wheel==0.46.3
717
+ widgetsnbextension==3.6.10
718
+ wordcloud==1.9.6
719
+ wrapt==2.1.1
720
+ wsproto==1.3.2
721
+ wurlitzer==3.1.1
722
+ xarray==2025.12.0
723
+ xarray-einstats==0.9.1
724
+ xformers==0.0.29.post1
725
+ xgboost==3.2.0
726
+ xlrd==2.0.2
727
+ xxhash==3.6.0
728
+ xyzservices==2025.11.0
729
+ yarl==1.22.0
730
+ ydf==0.15.0
731
+ yellowbrick==1.5
732
+ yfinance==0.2.66
733
+ zict==3.0.0
734
+ zipp==3.23.0
735
+ zstandard==0.25.0
v0.32_20260221_012419846_UTC/retraining_pipeline.py ADDED
@@ -0,0 +1,2244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from unsloth import FastLanguageModel, \
3
+ is_bfloat16_supported, UnslothTrainer, \
4
+ UnslothTrainingArguments
5
+
6
+ import torch
7
+
8
+ import os
9
+ import gc
10
+ import re
11
+ import sys
12
+ import json
13
+ import time
14
+ import shutil
15
+ import logging
16
+ import builtins
17
+
18
+ import importlib.util
19
+ from enum import Enum
20
+ from textwrap import dedent
21
+ from datetime import datetime, \
22
+ timezone
23
+
24
+ import polars as pl
25
+ from polars.exceptions import ComputeError
26
+
27
+ from jinja2 import Environment, FileSystemLoader
28
+
29
+ from huggingface_hub import list_repo_commits
30
+ from datasets import load_dataset, \
31
+ Dataset, DatasetDict
32
+ from datasets.config import HF_DATASETS_CACHE, \
33
+ HF_CACHE_HOME
34
+ from transformers import AutoTokenizer
35
+
36
+ from retrain_pipelines import __version__
37
+ from retrain_pipelines.dataset.hf_utils import \
38
+ get_lazy_df, get_column_info, \
39
+ iterable_dataset_multi_buffer_sampler, \
40
+ push_dataset_version_to_hub
41
+ from retrain_pipelines.dataset.tool_calls import \
42
+ count_tool_occurrences, plot_tools_occurences, \
43
+ column_words_stats, plot_words_count, \
44
+ get_unique_tools
45
+ from retrain_pipelines.utils.hf_utils import \
46
+ get_repo_version, get_new_repo_minor_version, \
47
+ push_files_to_hub_repo_branch
48
+
49
+ from retrain_pipelines.dag_engine.core import \
50
+ TaskPayload, task, dag, DagParam, ctx, UiCss
51
+
52
+ from retrain_pipelines.dag_engine.rp_logging import \
53
+ rp_redirect_stdout
54
+
55
+ from retrain_pipelines.dag_engine.sdk import \
56
+ ExecutionsIterator
57
+
58
+ from retrain_pipelines.utils import create_requirements
59
+
60
+
61
+ #--- helpers ----------------------------------------------------------------------------
62
+
63
+
64
+ logger = logging.getLogger(__name__)
65
+ logger.setLevel(logging.DEBUG)
66
+
67
+
68
+ class LocalServeReadinessEnum(Enum):
69
+ """
70
+ tracking local-serve (infra-validation)
71
+ status using a "3+"-states enum :
72
+ - "-1" for "not applicable"
73
+ (i.e. "model version not blessed"),
74
+ - "0/1" bool for failure/success.
75
+ """
76
+ NOT_APPLICABLE = -1
77
+ FAILURE = 0
78
+ FAILURE_NO_DOCKER = 2
79
+ SUCCESS = 1
80
+
81
+
82
+ def clear_gc():
83
+ """Convenience method to clear
84
+ the content of the garbage collector.
85
+ Forcing it to actually clear
86
+ any cuda tensor it holds.
87
+ """
88
+ for obj in gc.get_objects():
89
+ try:
90
+ if torch.is_tensor(obj) and obj.is_cuda:
91
+ del obj
92
+ except:
93
+ pass
94
+ gc.collect()
95
+
96
+
97
+ #--- retraining-pipeline elements -------------------------------------------------------
98
+
99
+
100
+ @task
101
+ def start() -> TaskPayload:
102
+ logger.info(f"{ctx.pipeline_name} - {ctx.exec_id}")
103
+ logging.getLogger("retrain_pipelines").setLevel(logging.INFO)
104
+
105
+ # inputs validation
106
+ repo_id_pattern = re.compile(
107
+ r"""
108
+ ^ # start
109
+ (?!.*\.\.) # no '..' anywhere
110
+ (?!.*--) # no '--' anywhere
111
+ (?: # legacy: single segment OR namespace/repo
112
+ [A-Za-z0-9._-]+ # legacy: gpt2, bert-base-uncased, etc.
113
+ |
114
+ [A-Za-z0-9._-]+/[A-Za-z0-9._-]+ # namespace/repo_name
115
+ )
116
+ $ # end
117
+ """,
118
+ re.VERBOSE
119
+ )
120
+ ctx.hf_dataset = json.loads(ctx.hf_dataset)
121
+ assert repo_id_pattern.match(ctx.hf_dataset["repo_id"]) is not None, \
122
+ f"Invalid repo_id format: {ctx.hf_dataset['repo_id']!r}"
123
+ ctx.augmentation_rate = float(ctx.augmentation_rate)
124
+ ctx.hf_enrich_dataset = json.loads(ctx.hf_enrich_dataset)
125
+ assert repo_id_pattern.match(ctx.hf_enrich_dataset["repo_id"]) is not None, \
126
+ f"Invalid repo_id format: {ctx.hf_enrich_dataset['repo_id']!r}"
127
+ ctx.enrichment_rate = float(ctx.enrichment_rate)
128
+ assert repo_id_pattern.match(ctx.dataset_repo_id) is not None, \
129
+ f"Invalid repo_id format: {dataset_repo_id!r}"
130
+ assert ctx.polars_engine in ["gpu", "cpu"]
131
+ ctx.hf_base_model = json.loads(ctx.hf_base_model)
132
+ assert repo_id_pattern.match(ctx.hf_base_model["repo_id"]) is not None, \
133
+ f"Invalid repo_id format: {ctx.hf_base_model['repo_id']!r}"
134
+ ctx.cpt_training_args = json.loads(ctx.cpt_training_args)
135
+ ctx.sft_training_args = json.loads(ctx.sft_training_args)
136
+ assert repo_id_pattern.match(ctx.model_repo_id) is not None, \
137
+ f"Invalid repo_id format: {model_repo_id!r}"
138
+
139
+ # GPU availability
140
+ logger.info(torch.cuda.get_device_name(0))
141
+ logger.info(torch.__version__)
142
+ ctx.engine = "cpu" if (
143
+ ctx.polars_engine == "gpu" and
144
+ not torch.cuda.is_available()
145
+ ) else ctx.polars_engine
146
+ logger.debug(f"Polars engine : {ctx.engine}")
147
+
148
+ # hf_dataset
149
+ hf_dataset_dict = \
150
+ get_lazy_df(
151
+ repo_id=ctx.hf_dataset["repo_id"],
152
+ commit_hash=ctx.hf_dataset["commit_hash"],
153
+ config_name=(
154
+ ctx.hf_dataset["config_name"] and
155
+ "" < ctx.hf_dataset["config_name"]
156
+ ),
157
+ hf_token=os.getenv("HF_TOKEN", None)
158
+ )
159
+ try:
160
+ logger.info(f"hf_dataset_dict lazy_df : {hf_dataset_dict['lazy_df']}")
161
+ logger.info(
162
+ f"{hf_dataset_dict['repo_id']}, " +
163
+ f"{hf_dataset_dict['commit_hash']} - " +
164
+ f"{hf_dataset_dict['commit_datetime']}\n" +
165
+ hf_dataset_dict["lazy_df"].explain()
166
+ )
167
+ except ComputeError as ex:
168
+ if "HF_TOKEN" not in os.environ:
169
+ logger.info("Does the Hugging Face-hosted dataset " +
170
+ "require authentication ?",
171
+ file=sys.stderr, flush=True)
172
+ raise ex
173
+ hf_dataset_version = get_repo_version(
174
+ repo_id=hf_dataset_dict["repo_id"],
175
+ revision=hf_dataset_dict["commit_hash"],
176
+ repo_type="dataset",
177
+ hf_token=os.getenv("HF_TOKEN", None)
178
+ )
179
+ hf_dataset_dict["version_label"] = (
180
+ f"{hf_dataset_version[0]}.{hf_dataset_version[1]}"
181
+ if sum(hf_dataset_version) > 0
182
+ else None
183
+ )
184
+ ctx.hf_dataset_dict = hf_dataset_dict
185
+
186
+ # hf_enrich_dataset
187
+ hf_enrich_dataset_dict = \
188
+ get_lazy_df(
189
+ repo_id=ctx.hf_enrich_dataset["repo_id"],
190
+ commit_hash=ctx.hf_enrich_dataset["commit_hash"],
191
+ config_name=(
192
+ ctx.hf_enrich_dataset["config_name"] and
193
+ "" < ctx.hf_enrich_dataset["config_name"]
194
+ ),
195
+ hf_token=os.getenv("HF_TOKEN", None)
196
+ )
197
+ hf_enrich_dataset_version = get_repo_version(
198
+ repo_id=hf_enrich_dataset_dict["repo_id"],
199
+ revision=hf_enrich_dataset_dict["commit_hash"],
200
+ repo_type="dataset",
201
+ hf_token=os.getenv("HF_TOKEN", None)
202
+ )
203
+ hf_enrich_dataset_dict["version_label"] = (
204
+ f"{hf_enrich_dataset_version[0]}.{hf_enrich_dataset_version[1]}"
205
+ if sum(hf_enrich_dataset_version) > 0
206
+ else None
207
+ )
208
+ logger.info(' ; '.join(f"{k}: {hf_enrich_dataset_dict[k]}"
209
+ for k in ['commit_hash',
210
+ 'commit_datetime']))
211
+ ctx.hf_enrich_dataset_dict = hf_enrich_dataset_dict
212
+
213
+ # hf_base_model
214
+ hf_base_model_revision=(
215
+ None if (rev_commit_hash:=ctx.hf_base_model["commit_hash"]) == ""
216
+ else rev_commit_hash
217
+ )
218
+ hf_base_model_commit = list_repo_commits(
219
+ repo_id=ctx.hf_base_model["repo_id"],
220
+ revision=hf_base_model_revision,
221
+ repo_type="model",
222
+ token=os.getenv("HF_TOKEN", None)
223
+ )[0]
224
+ # version major+minor=0 for non retrain-pipelines models
225
+ hf_base_model_version = get_repo_version(
226
+ repo_id=ctx.hf_base_model["repo_id"],
227
+ revision=hf_base_model_revision,
228
+ repo_type="model",
229
+ hf_token=os.getenv("HF_TOKEN", None)
230
+ )
231
+ ctx.hf_base_model_dict = {
232
+ "repo_id": ctx.hf_base_model["repo_id"],
233
+ "version_label": (
234
+ f"{hf_base_model_version[0]}.{hf_base_model_version[1]}"
235
+ if sum(hf_base_model_version) > 0
236
+ else None
237
+ ),
238
+ "commit_hash": hf_base_model_commit.commit_id,
239
+ "commit_datetime": \
240
+ hf_base_model_commit.created_at
241
+ }
242
+
243
+
244
+ ctx.model_version_blessed = False
245
+ ctx.current_blessed_exec = None
246
+ ctx.current_blessed_version_dict = None
247
+
248
+ ctx.retrain_pipelines = f"retrain-pipelines {__version__}"
249
+ ctx.retrain_pipeline_type = os.environ["retrain_pipeline_type"]
250
+
251
+
252
+ ctx.serving_artifacts_local_folder = os.path.realpath(os.path.join(
253
+ os.path.dirname(__file__), "..", "..", "serving_artifacts",
254
+ ctx.pipeline_name, str(ctx.exec_id)
255
+ ))
256
+
257
+ if not os.path.exists(ctx.serving_artifacts_local_folder):
258
+ os.makedirs(ctx.serving_artifacts_local_folder)
259
+
260
+
261
+ ctx.unsloth_dir = os.path.join(
262
+ ctx.serving_artifacts_local_folder,
263
+ "Unsloth"
264
+ )
265
+ logger.debug(f"unsloth_dir : {ctx.unsloth_dir}")
266
+ ctx.cpt_model_dir = os.path.join(ctx.unsloth_dir, "cpt_model")
267
+ ctx.sft_model_dir = os.path.join(ctx.unsloth_dir, "sft_model")
268
+
269
+ return None
270
+
271
+
272
+ @task
273
+ def eda(_) -> None:
274
+ """
275
+ exploratory data analysis.
276
+ """
277
+
278
+ ############################
279
+ # features and label #
280
+ # basic counts #
281
+ ############################
282
+ ctx.records_count = ctx.hf_dataset_dict["lazy_df"] \
283
+ .select(pl.len()).collect(engine=ctx.engine).item()
284
+ ctx.data_schema = get_column_info(
285
+ ctx.hf_dataset_dict["lazy_df"], engine=ctx.engine)
286
+ ############################
287
+
288
+ ############################
289
+ # Answers #
290
+ # tools count #
291
+ ############################
292
+ struct_schema = pl.Struct([
293
+ pl.Field("name",
294
+ pl.String
295
+ ),
296
+ pl.Field("arguments",
297
+ pl.List(pl.String) # we retrieve list of args names
298
+ # (without assigned values)
299
+ )
300
+ ])
301
+ tool_answer_occurrences_df = \
302
+ count_tool_occurrences(
303
+ ctx.hf_dataset_dict["lazy_df"],
304
+ ctx.hf_dataset["attributes"]["answers_attr"],
305
+ struct_schema) \
306
+ .collect(engine=ctx.engine)
307
+ print(f"{tool_answer_occurrences_df['occurrences'].sum():,} " +
308
+ f"query/tool-calls pairs")
309
+ fig = plot_tools_occurences(tool_answer_occurrences_df,
310
+ title_prefix="Dataset answers - ")
311
+ ctx.answers_tools_count_fig = fig
312
+ ############################
313
+
314
+ ############################
315
+ # Query #
316
+ # words count #
317
+ ############################
318
+ queries_max_length = ctx.hf_dataset_dict["lazy_df"].select(
319
+ pl.col(
320
+ ctx.hf_dataset["attributes"]["query_attr"]
321
+ ).str.len_chars().max().alias("max_query_length")
322
+ ).collect(engine=ctx.engine)
323
+ print(f"longuest query counts " +
324
+ f"{queries_max_length['max_query_length'][0]:,} characters")
325
+
326
+ # queries length quartiles
327
+ ctx.query_words_stats = \
328
+ column_words_stats(
329
+ ctx.hf_dataset_dict["lazy_df"],
330
+ ctx.hf_dataset["attributes"]["query_attr"]
331
+ ).collect(engine=ctx.engine)
332
+ print(ctx.query_words_stats.to_pandas().to_string(index=False))
333
+ print("Two thirds of the records have a query with less than " +
334
+ f"{ctx.query_words_stats['q3'][0]} words.")
335
+
336
+ fig = plot_words_count(
337
+ ctx.hf_dataset_dict["lazy_df"],
338
+ column_name=ctx.hf_dataset["attributes"]["query_attr"],
339
+ engine=ctx.engine)
340
+ ctx.words_count_fig = fig
341
+ ############################
342
+
343
+ ############################
344
+ # hf_enrich_dataset #
345
+ # Query words count #
346
+ ############################
347
+ enrich_question_words_stats = \
348
+ column_words_stats(
349
+ ctx.hf_enrich_dataset_dict['lazy_df'],
350
+ ctx.hf_enrich_dataset["query_attribute"],
351
+ column_attr_handler=eval(
352
+ ctx.hf_enrich_dataset["query_attribute_handler"])
353
+ ).collect(engine=ctx.engine)
354
+ print(enrich_question_words_stats.to_pandas()
355
+ .to_string(index=False))
356
+ del enrich_question_words_stats
357
+ ############################
358
+
359
+ return None
360
+
361
+
362
+ @task
363
+ def augment_data(_) -> None:
364
+ """
365
+ Add 'negative' examples, where
366
+ queries do not trigger any tool call.
367
+ To achieve that, we sample long user queries,
368
+ truncate at half words count, and
369
+ associate this to an empty list of tool-calls.
370
+ """
371
+ """
372
+ We only consider :
373
+ - records with longuest queries,
374
+ i.e. queries in the last quartile
375
+ of "queries with most word-counts"
376
+ (this is to avoid that 'truncated' queries
377
+ get really short)
378
+ - records with answers consisting
379
+ in a single tool-call
380
+ (in order to minimize the risk
381
+ that truncating actually gives
382
+ a valid answer with
383
+ one tool-call [or more])
384
+
385
+ Note on flow 'augmentation_rate' :
386
+ we add that many records (at most),
387
+ as quartiles size permits.
388
+ """
389
+
390
+ print("Sampling within the population with more than " +
391
+ str(ctx.query_words_stats['q3'][0]) +
392
+ " words (longest queries quartile) =>")
393
+
394
+ samples_count = \
395
+ int(ctx.records_count * ctx.augmentation_rate)
396
+ print(f"{ctx.augmentation_rate:.1%} would represent " +
397
+ f"{samples_count:,.0f} records to be sampled")
398
+
399
+ eligible_records_df = \
400
+ ctx.hf_dataset_dict["lazy_df"].filter(
401
+ pl.col(
402
+ ctx.hf_dataset["attributes"]["query_attr"]
403
+ )
404
+ .str.extract_all(r"\w+")
405
+ .map_elements(
406
+ lambda arr: len(arr),
407
+ return_dtype=pl.Int16)
408
+ .gt(ctx.query_words_stats['q3'][0])
409
+ & pl.col("answers")
410
+ .map_elements(
411
+ lambda x: len(json.loads(x)) == 1
412
+ if isinstance(x, str)
413
+ else False,
414
+ return_dtype=pl.Boolean)
415
+ ) \
416
+ .collect(engine=ctx.engine)
417
+ eligible_records_count = \
418
+ eligible_records_df.select(pl.len())["len"][0]
419
+ print(f"eligible_records_count : " +
420
+ f"{eligible_records_count:,.0f}")
421
+ samples_count = min(samples_count, eligible_records_count)
422
+ ctx.actual_augmentation_rate = \
423
+ samples_count / ctx.records_count
424
+ print("actual augmentation rate : " +
425
+ f"{ctx.actual_augmentation_rate:.1%}")
426
+ sampled_records_df = eligible_records_df.sample(
427
+ n=samples_count
428
+ )
429
+
430
+ ctx.augmented_records_df = \
431
+ sampled_records_df.with_columns(
432
+ pl.col("query")
433
+ .map_elements(
434
+ lambda query:
435
+ " ".join(
436
+ query.split()[
437
+ :len(query.split()) // 2]),
438
+ return_dtype=pl.Utf8)
439
+ .alias("truncated_query")
440
+ ).select([
441
+ pl.col("truncated_query").alias("query"),
442
+ pl.lit("[]").alias("answers")
443
+ ])
444
+ print(ctx.augmented_records_df.height,
445
+ ctx.augmented_records_df.columns)
446
+
447
+ return None
448
+
449
+
450
+ @task
451
+ def enrich_data(_) -> None:
452
+ """
453
+ Further enrich our dataset with 'negative' records from
454
+ another dataset (can be general-purpose text dataset)
455
+ as specified by the the flow 'hf_enrich_dataset' argument.
456
+ """
457
+ """
458
+ Note : we here use the Hugging Face `datasets` library
459
+ in 'streaming' mode for records sampling.
460
+ """
461
+
462
+ hf_enrich_ds = load_dataset(
463
+ path=ctx.hf_enrich_dataset["repo_id"],
464
+ name=ctx.hf_enrich_dataset["config_name"],
465
+ revision=ctx.hf_enrich_dataset_dict["commit_hash"],
466
+ streaming=True)
467
+ print(hf_enrich_ds["train"])
468
+
469
+ samples_count = \
470
+ int(ctx.records_count * ctx.enrichment_rate)
471
+ print(f"Samplig {samples_count:,.0f} records")
472
+
473
+ query_attribute_handler = \
474
+ eval(ctx.hf_enrich_dataset["query_attribute_handler"])
475
+ samples_iterator = iterable_dataset_multi_buffer_sampler(
476
+ hf_enrich_ds["train"],
477
+ total_samples=samples_count,
478
+ attributes_selector=\
479
+ (lambda x:query_attribute_handler(
480
+ x[ctx.hf_enrich_dataset["query_attribute"]])),
481
+ buffer_size=3_000,
482
+ num_passes=3,
483
+ seed=None
484
+ )
485
+ # Capitalize and add end punctuation if missing
486
+ start_time = time.time()
487
+ print("Starting sample enriching records, " +
488
+ "this may take some time if the source dataset " +
489
+ "has a complex structure..")
490
+ samples_list = [
491
+ s.capitalize() + ("" if s[-1] in ".!?" else "?")
492
+ for s in samples_iterator]
493
+ elapsed_time = time.time() - start_time
494
+ print(f".. sampling completed " +
495
+ f"({int(elapsed_time // 3_600)}h:" +
496
+ f"{int((elapsed_time % 3_600) // 60)}m:" +
497
+ f"{int(elapsed_time % 60)}s).")
498
+ enriched_records_df = pl.DataFrame(
499
+ {"query": samples_list,
500
+ "answers": \
501
+ ["[]"] * \
502
+ len(samples_list)}
503
+ )
504
+ ctx.enriched_records_df = enriched_records_df
505
+
506
+ return None
507
+
508
+
509
+ @task(ui_css=UiCss(background="#FF9900", color="#111827", border="#1F2937"))
510
+ def dataset_to_hub(_) -> None:
511
+ """
512
+ Push to hub dataset version
513
+ - continued pre-training dataset
514
+ - training and validation splits of the
515
+ augmented and enriched
516
+ supervised finetuning dataset
517
+ - readme with versioning info
518
+ """
519
+
520
+ #############################
521
+ # case of user-provided #
522
+ # documentation artifact(s) #
523
+ #############################
524
+ # note that user can provide either
525
+ # 'pipeline_card.py' or 'template.html'
526
+ # or 'dataset_readme.py'
527
+ # or 'dataset_readme_template.md'
528
+ # or 'model_readme.py'
529
+ # or 'model_readme_template.md'
530
+ # or any combination of those
531
+ # when specifying custom
532
+ # 'pipeline_card_artifacts_path'
533
+ if (
534
+ "dataset_readme_template.md" in
535
+ os.listdir(ctx.pipeline_card_artifacts_path)
536
+ ):
537
+ template_dir = ctx.pipeline_card_artifacts_path
538
+ else:
539
+ template_dir = os.path.dirname(
540
+ importlib.util.find_spec(
541
+ f"retrain_pipelines.pipeline_card."+
542
+ f"{os.getenv('retrain_pipeline_type')}"
543
+ ).origin)
544
+ print(f"template_dir : '{template_dir}'")
545
+ #############################
546
+ if "dataset_readme.py" in os.listdir(
547
+ ctx.pipeline_card_artifacts_path):
548
+ from retrain_pipelines.utils import \
549
+ get_get_dataset_readme_content
550
+ get_dataset_readme_content = \
551
+ get_get_dataset_readme_content(
552
+ ctx.pipeline_card_artifacts_path)
553
+ else:
554
+ from retrain_pipelines.pipeline_card import \
555
+ get_dataset_readme_content
556
+ #############################
557
+
558
+
559
+ #############################
560
+ # augmented & enriched #
561
+ # finetuning dataset #
562
+ #############################
563
+ merged_df = pl.concat([
564
+ # dataset
565
+ ctx.hf_dataset_dict["lazy_df"].select([
566
+ ctx.hf_dataset["attributes"]["query_attr"],
567
+ ctx.hf_dataset["attributes"]["answers_attr"]
568
+ ]).collect(engine=ctx.engine),
569
+ # truncated queries augmentation
570
+ ctx.augmented_records_df,
571
+ # enriching dataset
572
+ ctx.enriched_records_df
573
+ ]).sample(
574
+ # shuffling
575
+ fraction=1,
576
+ shuffle=True,
577
+ with_replacement=False
578
+ )
579
+ merged_df = merged_df.sample(fraction=1, shuffle=True)
580
+ merged_df.rechunk()
581
+ print(("merged_df", f"{merged_df.shape[0]:,.0F}",
582
+ merged_df.columns))
583
+
584
+ pandas_df = merged_df.to_pandas()
585
+ train_size = int(0.8 * len(pandas_df))
586
+ print(f"validation : {len(pandas_df) - train_size}")
587
+ sft_dataset = DatasetDict({
588
+ "train": Dataset.from_pandas(pandas_df[:train_size]),
589
+ "validation": Dataset.from_pandas(pandas_df[train_size:])
590
+ })
591
+ #############################
592
+
593
+ #############################
594
+ # continued pre-training #
595
+ # dataset #
596
+ #############################
597
+ struct_schema = pl.Struct([
598
+ pl.Field("name", pl.String),
599
+ pl.Field("description", pl.String),
600
+ pl.Field(
601
+ "parameters",
602
+ pl.String # Use String to allow
603
+ # for varying structures
604
+ # (different tools indeed having
605
+ # different sets of parameters
606
+ # i.e. different parameters counts,
607
+ # datatypes and names)
608
+ # so parsing must be tolerant.
609
+ )
610
+ ])
611
+ unique_tools_df = get_unique_tools(
612
+ ctx.hf_dataset_dict["lazy_df"],
613
+ tools_attr_name=\
614
+ ctx.hf_dataset["attributes"]["tools_attr"],
615
+ struct_schema=struct_schema
616
+ ).collect(engine=ctx.engine)
617
+ unique_tools_arrow_table = unique_tools_df.to_arrow()
618
+ ctx.unique_tools_dataset = \
619
+ Dataset(unique_tools_arrow_table)
620
+ print(ctx.unique_tools_dataset)
621
+ #############################
622
+
623
+ #############################
624
+ # DatasetDict #
625
+ # with multiple tables #
626
+ #############################
627
+ dataset_dict = DatasetDict({
628
+ "continued_pre_training": \
629
+ ctx.unique_tools_dataset,
630
+ "supervised_finetuning": sft_dataset
631
+ })
632
+ print(dataset_dict, flush=True)
633
+ #############################
634
+
635
+ #############################
636
+ # dataset README #
637
+ # from template #
638
+ #############################
639
+ commit_datetime = datetime.utcnow()
640
+ new_dataset_version_label = get_new_repo_minor_version(
641
+ repo_id=ctx.dataset_repo_id,
642
+ repo_type="dataset",
643
+ hf_token=os.getenv("HF_TOKEN", None))
644
+ readme_content = get_dataset_readme_content(
645
+ template_folder=template_dir,
646
+
647
+ hf_dataset_dict=ctx.hf_dataset_dict,
648
+ hf_enrich_dataset_dict=ctx.hf_enrich_dataset_dict,
649
+ dataset_dict=dataset_dict,
650
+
651
+ augmentation_rate=ctx.actual_augmentation_rate,
652
+ enrichment_rate=ctx.enrichment_rate,
653
+
654
+ version_label=new_dataset_version_label,
655
+ commit_datetime=commit_datetime,
656
+
657
+ pipeline_name=ctx.pipeline_name,
658
+ exec_id=ctx.exec_id,
659
+ engine=ctx.engine
660
+ )
661
+ #############################
662
+
663
+ dataset_commit_hash = push_dataset_version_to_hub(
664
+ repo_id=ctx.dataset_repo_id,
665
+ version_label=new_dataset_version_label,
666
+ timestamp_str=commit_datetime.strftime(
667
+ "%Y-%m-%d %H:%M:%S UTC"),
668
+ dataset_dict=dataset_dict,
669
+ dataset_readme_content=readme_content,
670
+ hf_token=os.getenv("HF_TOKEN", None)
671
+ )
672
+ if not dataset_commit_hash:
673
+ raise Exception(
674
+ "Failed to publish dataset version.")
675
+ print(f"https://huggingface.co/datasets/{ctx.dataset_repo_id}" +
676
+ f"/blob/{dataset_commit_hash}/README.md")
677
+ ctx.dataset_commit_dict = {
678
+ "repo_id": ctx.dataset_repo_id,
679
+ "commit_hash": dataset_commit_hash,
680
+ "version_label": new_dataset_version_label,
681
+ "commit_datetime": commit_datetime,
682
+ }
683
+
684
+ return None
685
+
686
+
687
+ @task
688
+ def continued_pre_training(_) -> None:
689
+ """
690
+ Gives the base model some additional intrinsic knowkledge
691
+ through continued pre-training.
692
+ See unsloth.ai/blog/contpretraining
693
+ """
694
+ from retrain_pipelines.model.hf_utils import \
695
+ plot_log_history
696
+
697
+ #######################################
698
+ # base-model and associated tokenizer #
699
+ # from Hub (or local cache) #
700
+ #######################################
701
+ ctx.max_seq_length = 2048
702
+ model, tokenizer = FastLanguageModel.from_pretrained(
703
+ model_name=ctx.hf_base_model_dict["repo_id"],
704
+ revision=ctx.hf_base_model_dict["commit_hash"],
705
+ max_seq_length=ctx.max_seq_length,
706
+ dtype=None,
707
+ load_in_4bit=False,
708
+ # case of a gated or private base-model
709
+ token=os.getenv("HF_TOKEN", None)
710
+ )
711
+ #######################################
712
+
713
+ #######################################
714
+ # dataset prompt_template mapping #
715
+ #######################################
716
+ tools_dataset = DatasetDict(
717
+ {"train": ctx.unique_tools_dataset})
718
+ print(tools_dataset)
719
+ tool_prompt_template = "tool: {}"
720
+ def formatting_prompts_func(tools_batch):
721
+ tools_batch = tools_batch["tool"]
722
+ outputs = []
723
+ for tool in tools_batch:
724
+ # Must add EOS_TOKEN,
725
+ # otherwise generation will go on forever!
726
+ text = tool_prompt_template.format(tool) + \
727
+ tokenizer.eos_token
728
+ outputs.append(text)
729
+ return { "tools" : outputs, }
730
+ cpt_dataset = tools_dataset["train"].map(
731
+ formatting_prompts_func, batched=True,)
732
+ #######################################
733
+
734
+ #######################################
735
+ # PEFT adapter #
736
+ # for continued pre-training #
737
+ #######################################
738
+ model = FastLanguageModel.get_peft_model(
739
+ model,
740
+ r = 128, # any number >0 ; 8, 16, 32, 64, 128, 256
741
+ target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
742
+ "gate_proj", "up_proj", "down_proj",
743
+ # Add for continued pretraining
744
+ "embed_tokens", "lm_head",],
745
+ lora_alpha = 32,
746
+ lora_dropout = 0, # Supports any, 0 is optimized
747
+ bias = "none", # Supports any, "none" is optimized
748
+ # True or "unsloth" for very long context
749
+ use_gradient_checkpointing = "unsloth",
750
+ use_rslora = True, # rank-stabilized LoRA
751
+ loftq_config = None, # LoftQ
752
+ #random_state = 3407,
753
+ )
754
+ #######################################
755
+
756
+ #######################################
757
+ # cpt_trainer #
758
+ #######################################
759
+ if (
760
+ "records_cap" in ctx.cpt_training_args and
761
+ ctx.cpt_training_args["records_cap"] is not None and
762
+ isinstance(ctx.cpt_training_args["records_cap"], int)
763
+ ):
764
+ cpt_dataset = cpt_dataset.take(
765
+ ctx.cpt_training_args["records_cap"])
766
+ print(f"cpt_dataset : {cpt_dataset}")
767
+
768
+ train_args = UnslothTrainingArguments(
769
+ # https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.save_strategy
770
+ per_device_train_batch_size=2,
771
+ gradient_accumulation_steps=8,
772
+
773
+ **{k: v for k, v in ctx.cpt_training_args.items()
774
+ if k != "records_cap"},
775
+
776
+ # 2 to 10x smaller learning rate
777
+ # for the embedding matrices
778
+ learning_rate=5e-5,
779
+ embedding_learning_rate=1e-5,
780
+
781
+ fp16=not is_bfloat16_supported(),
782
+ bf16=is_bfloat16_supported(),
783
+ logging_steps=1,
784
+ optim="adamw_8bit",
785
+ weight_decay=0.01,
786
+ lr_scheduler_type="linear",
787
+ #seed=3407,
788
+
789
+ output_dir=os.path.join(
790
+ ctx.unsloth_dir, "outputs", "cpt"),
791
+ save_total_limit = 2,
792
+
793
+ report_to="tensorboard",
794
+ logging_dir=os.path.join(
795
+ ctx.sft_model_dir,
796
+ "runs", "cpt")
797
+ )
798
+
799
+ trainer = UnslothTrainer(
800
+ model=model, tokenizer=tokenizer,
801
+ train_dataset=cpt_dataset,
802
+ dataset_text_field="tools",
803
+ max_seq_length=ctx.max_seq_length,
804
+ dataset_num_proc=2,
805
+ args=train_args,
806
+ )
807
+ #######################################
808
+
809
+ #######################################
810
+ # Show current memory stats #
811
+ #######################################
812
+ torch.cuda.ipc_collect()
813
+ torch.cuda.empty_cache()
814
+ _ = gc.collect()
815
+
816
+ gpu_stats = torch.cuda.get_device_properties(0)
817
+ ctx.start_gpu_memory = \
818
+ round(torch.cuda.max_memory_reserved()
819
+ / 1024 / 1024 / 1024, 3)
820
+ ctx.max_memory = \
821
+ round(gpu_stats.total_memory
822
+ / 1024 / 1024 / 1024, 3)
823
+ print(f"GPU = {gpu_stats.name}. " +
824
+ f"Max memory = {ctx.max_memory} GB.")
825
+ print(f"{ctx.start_gpu_memory} GB of memory reserved.")
826
+ #######################################
827
+
828
+ ctx.cpt_traces_file_fullname = os.path.join(
829
+ ctx.unsloth_dir, "cpt_trainer_traces.txt")
830
+ logger.info(
831
+ "Training started. " +
832
+ f"Check [underline]{ctx.cpt_traces_file_fullname}[/] for live traces " +
833
+ "or go watch your [white bold]TensorBoard[/] charts live updates !"
834
+ )
835
+ with open(ctx.cpt_traces_file_fullname, 'w') as f:
836
+ with rp_redirect_stdout(f):
837
+ trainer_stats = trainer.train()
838
+ print(f"{trainer_stats.metrics['train_runtime']} " +
839
+ f"seconds used for training " +
840
+ f"({round(trainer_stats.metrics['train_runtime']/60, 2)}" +
841
+ f" minutes).")
842
+
843
+ ctx.cpt_log_history = trainer.state.log_history
844
+ ctx.cpt_log_history_fig = \
845
+ plot_log_history(
846
+ ctx.cpt_log_history,
847
+ title="Continued pretraining loss"
848
+ )
849
+ del trainer
850
+ # logger.debug(f"Continued pretraining loss curve : {ctx.cpt_log_history}")
851
+
852
+ model.save_pretrained_merged(
853
+ save_directory=ctx.cpt_model_dir,
854
+ tokenizer=tokenizer,
855
+ save_method="lora"
856
+ )
857
+ print(f"cpt_model_dir : {ctx.cpt_model_dir}\n")
858
+
859
+ # vRAM & RAM cleanup
860
+ # (incl. force-delete all CUDA tensors in gc)
861
+ del model
862
+ del tokenizer
863
+ clear_gc()
864
+ torch.cuda.empty_cache()
865
+ torch.cuda.synchronize()
866
+ print(f"After cleanup: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
867
+
868
+ return None
869
+
870
+
871
+ @task
872
+ def supervised_finetuning(_) -> None:
873
+ """
874
+ Trains the model on tool-calling
875
+ task specialization.
876
+ """
877
+ from retrain_pipelines.model.hf_utils import \
878
+ plot_log_history
879
+
880
+ model, tokenizer = FastLanguageModel.from_pretrained(
881
+ model_name=ctx.cpt_model_dir,
882
+ max_seq_length=ctx.max_seq_length,
883
+ dtype=None,
884
+ load_in_4bit=False,
885
+ )
886
+ # !!!! bug fix BEGIN !!!!
887
+ # otherwise, 'embed_tokens' and 'lm_head'
888
+ # trained during CPT are "ignored",
889
+ # i.e. not saved after SFT
890
+ # (note that, alternatively, we could also
891
+ # do this fix after sft-training and
892
+ # just before saving ;
893
+ # which would be equivalent to
894
+ # freezing embeddings during finetuning
895
+ # for better pretrained knowledge retention)
896
+ # @see https://www.reddit.com/r/unsloth/comments/1dtzcd6/fastlanguagemodelpatch_peft_model_changing/
897
+ model.model.model.embed_tokens.modules_to_save.default.to(
898
+ device="cuda:0",
899
+ dtype=torch.float32,
900
+ non_blocking=True)
901
+ model.model.model.embed_tokens.modules_to_save.default \
902
+ .requires_grad_(True)
903
+ model.model.lm_head.modules_to_save.default.to(
904
+ device="cuda:0",
905
+ dtype=torch.float32,
906
+ non_blocking=True)
907
+ model.model.lm_head.modules_to_save.default \
908
+ .requires_grad_(True)
909
+ # !!!! bug fix END !!!!
910
+
911
+ #######################################
912
+ # dataset prompt_template mapping #
913
+ #######################################
914
+ # download from Hub (or get from local cache)
915
+ queries_dataset = load_dataset(
916
+ path=ctx.dataset_commit_dict["repo_id"],
917
+ name="supervised_finetuning",
918
+ revision=ctx.dataset_commit_dict["commit_hash"],
919
+ token=os.getenv("HF_TOKEN", None))
920
+ print(f"HF_DATASETS_CACHE : {HF_DATASETS_CACHE}") # HF_CACHE_HOME
921
+ ctx.sft_prompt_template = dedent("""
922
+ You specialize in generating tool calls. Given a query, your task is to return a list of tool calls based on your knowledge of known tools.
923
+
924
+ Rules:
925
+ 1. You can only use tools you know. Do not create new tools under any circumstances.
926
+ 2. If a query does not match any known tool, return an empty list ([]).
927
+ 3. If information is missing to use a known tool, do not attempt to use it.
928
+ 4. Your response must always be a valid JSON array, and nothing else.
929
+
930
+ Be precise and do not guess.
931
+
932
+ # query:
933
+ {}
934
+ # response:
935
+ {}
936
+ """).strip()
937
+ tokenizer.chat_template = ctx.sft_prompt_template
938
+
939
+ EOS_TOKEN = tokenizer.eos_token
940
+ def formatting_prompts_func(records):
941
+ query = records["query"]
942
+ tools = records["answers"]
943
+ outputs = []
944
+ for query, tools in zip(query, tools):
945
+ # Must add EOS_TOKEN,
946
+ # otherwise your generation will go on forever
947
+ text = ctx.sft_prompt_template.format(query, tools) \
948
+ + EOS_TOKEN
949
+ outputs.append(text)
950
+ return { "text" : outputs, }
951
+ sft_train_dataset = queries_dataset["train"].map(
952
+ formatting_prompts_func, batched=True)
953
+ sft_valid_dataset = queries_dataset["validation"].map(
954
+ formatting_prompts_func, batched=True,)
955
+ #######################################
956
+
957
+ #######################################
958
+ # PEFT adapter #
959
+ # for supervised finetuning #
960
+ #######################################
961
+ # for cases where CPT has been merged into overall model
962
+ # otherwize, keep on training current LoRa adapter
963
+ # model = FastLanguageModel.get_peft_model(
964
+ # model,
965
+ # r = 128, # any number >0 ; 8, 16, 32, 64, 128, 256
966
+ # target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
967
+ # "gate_proj", "up_proj", "down_proj"],
968
+ # lora_alpha = 32,
969
+ # lora_dropout = 0, # Supports any, but = 0 is optimized
970
+ # bias = "none", # Supports any, but = "none" is optimized
971
+ # # True or "unsloth" for very long context
972
+ # use_gradient_checkpointing = "unsloth",
973
+ # random_state = 3407,
974
+ # use_rslora = True, # rank stabilized LoRA
975
+ # loftq_config = None, # LoftQ
976
+ # )
977
+ #######################################
978
+
979
+ #######################################
980
+ # sft_trainer #
981
+ #######################################
982
+ split = sft_train_dataset.train_test_split(
983
+ test_size=1000,
984
+ #seed=42
985
+ )
986
+ train_dataset = split['train']
987
+ eval_dataset = split['test']
988
+ if (
989
+ "records_cap" in ctx.sft_training_args and
990
+ ctx.sft_training_args["records_cap"] is not None and
991
+ isinstance(ctx.sft_training_args["records_cap"], int)
992
+ ):
993
+ train_dataset = train_dataset.take(
994
+ ctx.sft_training_args["records_cap"])
995
+ eval_dataset = eval_dataset.take(
996
+ ctx.sft_training_args["records_cap"])
997
+ print(f"train_dataset : {train_dataset}")
998
+ print(f"eval_dataset : {eval_dataset}")
999
+
1000
+ train_args = UnslothTrainingArguments(
1001
+ per_device_train_batch_size=2,
1002
+ gradient_accumulation_steps=8,
1003
+
1004
+ **{k: v for k, v in ctx.sft_training_args.items()
1005
+ if k != "records_cap"},
1006
+
1007
+ per_device_eval_batch_size=2,
1008
+ eval_steps=200,
1009
+ eval_strategy="steps",
1010
+ do_eval=True,
1011
+
1012
+ learning_rate=5e-5,
1013
+ # embedding_learning_rate=1e-5, # Optionally here
1014
+
1015
+ fp16=not is_bfloat16_supported(),
1016
+ bf16=is_bfloat16_supported(),
1017
+
1018
+ optim="adamw_8bit",
1019
+ weight_decay=0.00,
1020
+ lr_scheduler_type="linear",
1021
+ #seed=3407,
1022
+
1023
+ output_dir=os.path.join(
1024
+ ctx.unsloth_dir, "outputs", "sft"),
1025
+ save_total_limit=2,
1026
+
1027
+ disable_tqdm=True,
1028
+ logging_steps=1,
1029
+ report_to="tensorboard",
1030
+ logging_dir=os.path.join(
1031
+ ctx.sft_model_dir,
1032
+ "runs", "sft")
1033
+ )
1034
+
1035
+ trainer = UnslothTrainer(
1036
+ model=model, tokenizer=tokenizer,
1037
+ train_dataset=train_dataset,
1038
+ dataset_text_field="text",
1039
+ eval_dataset=eval_dataset,
1040
+ max_seq_length=ctx.max_seq_length,
1041
+ dataset_num_proc=8,
1042
+ args=train_args
1043
+ )
1044
+ trainer.can_return_loss = True
1045
+ #######################################
1046
+
1047
+ #######################################
1048
+ # Show current memory stats #
1049
+ #######################################
1050
+ torch.cuda.ipc_collect()
1051
+ torch.cuda.empty_cache()
1052
+ _ = gc.collect()
1053
+
1054
+ used_memory = \
1055
+ round(torch.cuda.max_memory_reserved()
1056
+ /1024/1024/1024, 3)
1057
+ used_memory_for_lora = \
1058
+ round(used_memory-ctx.start_gpu_memory, 3)
1059
+ used_percentage = \
1060
+ round(used_memory/ctx.max_memory*100, 3)
1061
+ lora_percentage = \
1062
+ round(used_memory_for_lora/ctx.max_memory*100,
1063
+ 3)
1064
+ print(f"Peak reserved memory = " +
1065
+ f"{used_memory} GB.")
1066
+ print(f"Peak reserved memory for " +
1067
+ f"training = {used_memory_for_lora} " +
1068
+ f"GB.")
1069
+ print(f"Peak reserved memory % of " +
1070
+ f"max memory = {used_percentage} %.")
1071
+ print(f"Peak reserved memory for training " +
1072
+ f"% of max memory = {lora_percentage} %.")
1073
+ #######################################
1074
+
1075
+ ctx.sft_traces_file_fullname = os.path.join(
1076
+ ctx.unsloth_dir, "sft_trainer_traces.txt")
1077
+ logger.info(
1078
+ "Training started. " +
1079
+ f"Check [underline]{ctx.sft_traces_file_fullname}[/] for live traces " +
1080
+ "or go watch your [white bold]TensorBoard[/] charts live updates !"
1081
+ )
1082
+ with open(ctx.sft_traces_file_fullname, 'w') as f:
1083
+ with rp_redirect_stdout(f):
1084
+ trainer_stats = trainer.train()
1085
+ print(f"{trainer_stats.metrics['train_runtime']} " +
1086
+ f"seconds used for training " +
1087
+ f"({round(trainer_stats.metrics['train_runtime']/60, 2)}" +
1088
+ f" minutes).")
1089
+
1090
+ ctx.sft_log_history = trainer.state.log_history
1091
+ ctx.sft_log_history_fig = \
1092
+ plot_log_history(
1093
+ ctx.sft_log_history,
1094
+ title="Supervised finetuning loss"
1095
+ )
1096
+ del trainer
1097
+
1098
+ model.save_pretrained_merged(
1099
+ ctx.sft_model_dir, tokenizer,
1100
+ save_method = "lora"
1101
+ )
1102
+ print(f"sft_model_dir : {ctx.sft_model_dir}\n")
1103
+
1104
+ # vRAM & RAM cleanup
1105
+ # (incl. force-delete all CUDA tensors in gc)
1106
+ del model
1107
+ del tokenizer
1108
+ clear_gc()
1109
+ torch.cuda.empty_cache()
1110
+ torch.cuda.synchronize()
1111
+ print(f"After cleanup: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
1112
+
1113
+ return None
1114
+
1115
+
1116
+ @task
1117
+ def evaluate_model(_) -> None:
1118
+ """
1119
+ Batch inference on the SFT validation dataset.
1120
+ """
1121
+ from retrain_pipelines.model import \
1122
+ infer_validation, compute_counts_n_metrics, \
1123
+ plot_validation_completions
1124
+
1125
+ ######################################################
1126
+ # loading trained adapter #
1127
+ ######################################################
1128
+ # Unsloth [and hf transformers before it] #
1129
+ # (if loading both model & tokenizer at once #
1130
+ # same as we did in prior tasks, but now #
1131
+ # with tokenizer.chat_template being set #
1132
+ # in tokenizer.config) is forcing on us some kind of #
1133
+ # chat_template format hard-requirements. #
1134
+ ######################################################
1135
+ # load base from cache
1136
+ # (with base tokenizer, which we ignore)
1137
+ model, _ = FastLanguageModel.from_pretrained(
1138
+ model_name=ctx.hf_base_model_dict["repo_id"],
1139
+ revision=ctx.hf_base_model_dict["commit_hash"],
1140
+ max_seq_length=ctx.max_seq_length,
1141
+ dtype=None,
1142
+ load_in_4bit=False,
1143
+ # case of a gated or private base-model
1144
+ token=os.getenv("HF_TOKEN", None)
1145
+ )
1146
+ model = FastLanguageModel.for_inference(model)
1147
+ # load our CPT+SFT trained & locally-saved adapter
1148
+ model.load_adapter(peft_model_id=ctx.sft_model_dir)
1149
+ # Separately load our (potentially trained &)
1150
+ # locally-saved adapter-tokenizer
1151
+ # (loading it below via HF and not Unsloth)
1152
+ tokenizer = AutoTokenizer.from_pretrained(
1153
+ pretrained_model_name_or_path=ctx.sft_model_dir
1154
+ )
1155
+ ######################################################
1156
+
1157
+ ######################################################
1158
+ # validation dataset #
1159
+ ######################################################
1160
+ # download from Hub (or get from local cache)
1161
+ queries_dataset = load_dataset(
1162
+ path=ctx.dataset_commit_dict["repo_id"],
1163
+ name="supervised_finetuning",
1164
+ revision=ctx.dataset_commit_dict["commit_hash"],
1165
+ token=os.getenv("HF_TOKEN", None))
1166
+ if (
1167
+ "records_cap" in ctx.sft_training_args and
1168
+ ctx.sft_training_args["records_cap"] is not None and
1169
+ isinstance(ctx.sft_training_args["records_cap"], int)
1170
+ ):
1171
+ validation_data = queries_dataset["validation"].take(
1172
+ ctx.sft_training_args["records_cap"])
1173
+ else:
1174
+ validation_data = queries_dataset["validation"]
1175
+ print(validation_data, flush=True)
1176
+ ######################################################
1177
+
1178
+ ctx.max_new_tokens = 400
1179
+ start_time = time.time()
1180
+ validation_results = infer_validation(
1181
+ tokenizer=tokenizer,
1182
+ model=model,
1183
+ validation_data=validation_data,
1184
+ prompt_template=tokenizer.chat_template,
1185
+ batch_size=32, # 64,
1186
+ queries_attr_name=\
1187
+ ctx.hf_dataset["attributes"]["query_attr"],
1188
+ answers_attr_name=\
1189
+ ctx.hf_dataset["attributes"]["answers_attr"],
1190
+ max_new_tokens=ctx.max_new_tokens,
1191
+ device="cuda"
1192
+ )
1193
+ print("infer_validation - Elapsed time: " +
1194
+ f"{(time.time() - start_time):.2f} seconds")
1195
+ ctx.validation_results = validation_results # <= to artifacts store
1196
+
1197
+ eval_df = pl.LazyFrame(validation_results)
1198
+
1199
+ records = eval_df.with_columns(
1200
+ (pl.col("answer") == pl.col("completion")) \
1201
+ .alias("is_ground_truth_identical")
1202
+ ).collect() #engine=ctx.engine)
1203
+ print("perfect characters-match accuracy : " +
1204
+ str(records['is_ground_truth_identical'].mean()))
1205
+
1206
+ eval_metrics_df = compute_counts_n_metrics(
1207
+ eval_df, is_format_fault_tolerant=True)
1208
+ overall_metrics_df = eval_metrics_df.select([
1209
+ pl.col("precision").mean(),
1210
+ pl.col("recall").mean(),
1211
+ pl.col("f1").mean(),
1212
+ pl.col("jaccard").mean()
1213
+ ]).collect() #engine=ctx.engine)
1214
+ ctx.perf_metrics = overall_metrics_df.row(0, named=True)
1215
+ print(ctx.perf_metrics)
1216
+
1217
+ ctx.validation_completions_fig = \
1218
+ plot_validation_completions(
1219
+ eval_metrics_df, engine=ctx.engine)
1220
+
1221
+ # vRAM & RAM cleanup
1222
+ # (incl. force-delete all CUDA tensors in gc)
1223
+ del model
1224
+ del tokenizer
1225
+ clear_gc()
1226
+ torch.cuda.empty_cache()
1227
+ torch.cuda.synchronize()
1228
+ print(f"After cleanup: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
1229
+
1230
+ return None
1231
+
1232
+
1233
+ @task
1234
+ def model_version_blessing(_) -> None:
1235
+ """
1236
+ Comparing newly-retrained model version
1237
+ against best-performing predecessor.
1238
+ """
1239
+ """
1240
+ Note: for Hugging Face integrated pipelines,
1241
+ we compare against lastest commit of main branch
1242
+ of the model repository there.
1243
+ When it comes to local "mf_run_id" of the pipeline run
1244
+ having generated that best prior model version
1245
+ (retrieved from model card metadata from HF yaml section),
1246
+ we check against records of the herein ML-framework instance,
1247
+ as "prior best version" of the model here beign retrained
1248
+ may have been originated from another one
1249
+ than the one executing the current retraining
1250
+ (in which case, we simply don't includ a "local" hyperlink
1251
+ in the model version pipeline_cards that will be
1252
+ produced later in the herein pipeline run).
1253
+ """
1254
+ from retrain_pipelines.model.hf_utils import \
1255
+ current_blessed_model_version_dict
1256
+
1257
+ main_perf_metric_name = "jaccard"
1258
+
1259
+ current_blessed_version_dict = \
1260
+ current_blessed_model_version_dict(
1261
+ repo_id=ctx.model_repo_id,
1262
+ hf_token=os.getenv("HF_TOKEN", None)
1263
+ )
1264
+ print("current_blessed_version_dict : " +
1265
+ str(current_blessed_version_dict))
1266
+
1267
+ if current_blessed_version_dict is None:
1268
+ print("case 'no prior blessed model version found"
1269
+ " => blessing.'")
1270
+ ctx.model_version_blessed = True
1271
+
1272
+ elif (
1273
+ main_perf_metric_name in
1274
+ current_blessed_version_dict["perf_metrics"]
1275
+ ):
1276
+ current_blessed_exec_id = \
1277
+ current_blessed_version_dict["exec_id"]
1278
+ print(f"current_blessed_exec_id : {current_blessed_exec_id}")
1279
+ current_blessed_metric_value = \
1280
+ current_blessed_version_dict[
1281
+ "perf_metrics"][main_perf_metric_name]
1282
+
1283
+ ctx.model_version_blessed = (
1284
+ ctx.perf_metrics[main_perf_metric_name] >=
1285
+ current_blessed_metric_value
1286
+ )
1287
+
1288
+ if not ctx.model_version_blessed:
1289
+ ctx.current_blessed_version_dict = \
1290
+ current_blessed_version_dict
1291
+ # may have failed after the "pipeline_card" task,
1292
+ # so we do not filter on success
1293
+ for execution in ExecutionsIterator(
1294
+ exec_name=ctx.pipeline_name,
1295
+ page_size=10
1296
+ ):
1297
+ if str(execution.id) == current_blessed_exec_id:
1298
+ # Has the execution seen task "pipeline_card" which
1299
+ # completed successfully
1300
+ # ("execution" has generated a custom pipeline-card artifact) ?
1301
+ # If not, hyperlink generation will later fail.
1302
+ run_has_custom_card_artifact = (len([
1303
+ t for t in execution.get_tasks_with_name(
1304
+ task_type_name="pipeline_card")
1305
+ if t.end_timestamp and t.success
1306
+ ]) == 1)
1307
+ if not run_has_custom_card_artifact:
1308
+ print(
1309
+ f"Execution #{current_blessed_exec_id} " +
1310
+ "Doesn't seem to have successfully " +
1311
+ "generated a pipeline-card artifact.",
1312
+ file=sys.stderr, flush=True)
1313
+
1314
+ else:
1315
+ # further filtering on successful executions that are
1316
+ # retraining of a prior version of the same model
1317
+ # (to minimize the risk that this was obtained
1318
+ # on another DAG-engine instance)
1319
+ if (
1320
+ execution.get_attr("model_version_blessed") and
1321
+ execution.get_attr("model_repo_id") or "" == \
1322
+ ctx.model_repo_id
1323
+ ):
1324
+ ctx.current_blessed_exec = execution
1325
+
1326
+ break
1327
+
1328
+ if not ctx.current_blessed_exec:
1329
+ print(
1330
+ "Couldn't find blessed execution " +
1331
+ f"{current_blessed_exec_id} !\n" +
1332
+ "It seems that prior blessed execution was " +
1333
+ "executed on another DAG-engine instance.",
1334
+ file=sys.stderr, flush=True)
1335
+
1336
+ print("new : " +
1337
+ str(ctx.perf_metrics[main_perf_metric_name]) +
1338
+ " - previous best : " +
1339
+ str(current_blessed_metric_value) +
1340
+ " - model_version_blessing : " +
1341
+ str(ctx.model_version_blessed))
1342
+
1343
+ else:
1344
+ raise Exception(
1345
+ "Performance metric '" +
1346
+ main_perf_metric_name +
1347
+ "' can't be found in eval results " +
1348
+ "from blessed execution " +
1349
+ str(current_blessed_version_dict[
1350
+ "exec_id"]) + " !")
1351
+
1352
+ # ctx.model_version_blessed = True ### DEBUG - DELETE ###
1353
+
1354
+ return None
1355
+
1356
+
1357
+ @task(ui_css=UiCss(background="#FF9900", color="#111827", border="#1F2937"))
1358
+ def model_to_hub(_) -> None:
1359
+ """
1360
+ Push to hub model version, including
1361
+ readme with versioning info.
1362
+ """
1363
+
1364
+ #############################
1365
+ # case of user-provided #
1366
+ # documentation artifact(s) #
1367
+ #############################
1368
+ # note that user can provide either
1369
+ # 'pipeline_card.py' or 'template.html'
1370
+ # or 'dataset_readme.py'
1371
+ # or 'dataset_readme_template.md'
1372
+ # or 'model_readme.py'
1373
+ # or 'model_readme_template.md'
1374
+ # or any combination of those
1375
+ # when specifying custom
1376
+ # 'pipeline_card_artifacts_path'
1377
+ if (
1378
+ "model_readme_template.md" in
1379
+ os.listdir(ctx.pipeline_card_artifacts_path)
1380
+ ):
1381
+ template_dir = ctx.pipeline_card_artifacts_path
1382
+ else:
1383
+ template_dir = os.path.dirname(
1384
+ importlib.util.find_spec(
1385
+ f"retrain_pipelines.pipeline_card."+
1386
+ f"{os.getenv('retrain_pipeline_type')}"
1387
+ ).origin)
1388
+ print(f"template_dir : '{template_dir}'")
1389
+ #############################
1390
+ if "model_readme.py" in os.listdir(
1391
+ ctx.pipeline_card_artifacts_path):
1392
+ from retrain_pipelines.utils import \
1393
+ get_get_model_readme_content
1394
+ get_model_readme_content = \
1395
+ get_get_model_readme_content(
1396
+ ctx.pipeline_card_artifacts_path)
1397
+ else:
1398
+ from retrain_pipelines.pipeline_card import \
1399
+ get_model_readme_content
1400
+ #############################
1401
+ from retrain_pipelines.model.hf_utils import \
1402
+ push_model_version_to_hub
1403
+
1404
+ #############################
1405
+ # model README #
1406
+ # from template #
1407
+ #############################
1408
+ commit_datetime = datetime.utcnow()
1409
+ new_model_version_label = get_new_repo_minor_version(
1410
+ repo_id=ctx.model_repo_id,
1411
+ repo_type="model",
1412
+ hf_token=os.getenv("HF_TOKEN", None))
1413
+ readme_content = get_model_readme_content(
1414
+ template_folder=template_dir,
1415
+
1416
+ model_repo_id=ctx.model_repo_id,
1417
+
1418
+ base_model_dict=ctx.hf_base_model_dict,
1419
+ training_dataset_dict=ctx.dataset_commit_dict,
1420
+
1421
+ version_label=new_model_version_label,
1422
+ commit_datetime=commit_datetime,
1423
+ perf_metrics=ctx.perf_metrics,
1424
+
1425
+ pipeline_name=ctx.pipeline_name,
1426
+ exec_id=ctx.exec_id
1427
+ )
1428
+ #############################
1429
+
1430
+ print("Pushing model version to HF hub " +
1431
+ ("(blessed). " if ctx.model_version_blessed
1432
+ else "(not blessed). ") +
1433
+ "May take a while..",
1434
+ flush=True)
1435
+ model_commit_hash = push_model_version_to_hub(
1436
+ repo_id=ctx.model_repo_id,
1437
+ model_version_blessed=\
1438
+ ctx.model_version_blessed,
1439
+ version_label=new_model_version_label,
1440
+ timestamp_str=commit_datetime.strftime(
1441
+ "%Y-%m-%d %H:%M:%S UTC"),
1442
+ model_dir=ctx.sft_model_dir,
1443
+ model_readme_content=readme_content,
1444
+ hf_token=os.getenv("HF_TOKEN", None)
1445
+ )
1446
+ if not model_commit_hash:
1447
+ raise Exception(
1448
+ "Failed to publish model version.")
1449
+ print("Push of model version to HF hub completed.",
1450
+ flush=True)
1451
+ print(f"https://huggingface.co/{ctx.model_repo_id}" +
1452
+ f"/blob/{model_commit_hash}/README.md")
1453
+
1454
+ ctx.model_commit_dict = {
1455
+ "repo_id": ctx.model_repo_id,
1456
+ "commit_hash": model_commit_hash,
1457
+ "version_label": new_model_version_label,
1458
+ "commit_datetime": commit_datetime,
1459
+ }
1460
+
1461
+ return None
1462
+
1463
+
1464
+ @task
1465
+ def infra_validator(_) -> None:
1466
+ """
1467
+ If the trained model version is blessed,
1468
+ validate serving.
1469
+ """
1470
+ """
1471
+ Note that using isolated virtual env
1472
+ (using @conda task decorator)
1473
+ is advisable to not embark the whole
1474
+ pipeline dependencies into the local server.
1475
+ We don't for educational purpose,
1476
+ keep things "simple" to grasp
1477
+ as well as to avoid forcing conda
1478
+ (for instance miniconda) as
1479
+ a virtual environment management mean
1480
+ to the user.
1481
+ """
1482
+ """
1483
+ Note : We load base model from HF-cache
1484
+ (mounted as /huggingface_hub_cache
1485
+ docker volume) and adapter from local dir
1486
+ (mounted as /FuncCallAdater docker volume.
1487
+ """
1488
+
1489
+ ctx.local_serve_is_ready = LocalServeReadinessEnum.NOT_APPLICABLE
1490
+
1491
+ if ctx.model_version_blessed:
1492
+ from retrain_pipelines.utils.docker import \
1493
+ env_has_docker
1494
+
1495
+ if env_has_docker():
1496
+ model_module_dir = \
1497
+ os.path.dirname(
1498
+ importlib.util.find_spec(
1499
+ "retrain_pipelines.model." +
1500
+ os.getenv('retrain_pipeline_type')
1501
+ ).origin)
1502
+
1503
+ # server & data-model & server-config modules artifacts
1504
+ files_to_copy = [
1505
+ "litserve_server.py",
1506
+ "litserve_datamodel.py",
1507
+ "litserve_serverconfig.py",
1508
+ ".dockerignore" # docker context loading
1509
+ # at image-build time,
1510
+ # exclude model weights
1511
+ ]
1512
+ for filename in files_to_copy:
1513
+ shutil.copy(
1514
+ os.path.join(model_module_dir, "litserve",
1515
+ filename),
1516
+ os.path.join(ctx.serving_artifacts_local_folder,
1517
+ filename)
1518
+ )
1519
+
1520
+ # save dependencies as artifact
1521
+ create_requirements(ctx.serving_artifacts_local_folder,
1522
+ exclude=["numpy", # version conflict
1523
+ # quick fix
1524
+ "cudf-polars-.*", "cuda-python",
1525
+ "nvidia-.*", "(py)?libcudf-.*",
1526
+ "nvtx", "rmm-.*", "litserve",
1527
+ "protobuf", "grpc.*",
1528
+ "tensorboard",
1529
+ ".*retrain-pipelines.*"]
1530
+ )
1531
+
1532
+ # server config yaml
1533
+ env = Environment(loader=FileSystemLoader(
1534
+ os.path.join(model_module_dir, "litserve")))
1535
+ template = env.get_template(
1536
+ "litserve_serverconfig_template.yaml")
1537
+ server_config_data = {
1538
+ "port": "8000",
1539
+ "max_seq_length": ctx.max_seq_length,
1540
+ "max_new_token": ctx.max_new_tokens,
1541
+ "base_model": {
1542
+ "repo_id": ctx.hf_base_model_dict["repo_id"],
1543
+ "revision": ctx.hf_base_model_dict["commit_hash"]
1544
+ },
1545
+ "adapters": [
1546
+ {
1547
+ "name": "func_caller",
1548
+ "path": "/FuncCallAdapter"
1549
+ }
1550
+ ]
1551
+ }
1552
+ server_config_yaml = template.render(server_config_data)
1553
+ print(server_config_yaml)
1554
+ with open(os.path.join(
1555
+ ctx.serving_artifacts_local_folder,
1556
+ "litserve_serverconfig.yaml"), 'w'
1557
+ ) as output_file:
1558
+ output_file.write(server_config_yaml)
1559
+
1560
+ # Dockerfile
1561
+ env = Environment(loader=FileSystemLoader(
1562
+ os.path.join(model_module_dir)))
1563
+ template = env.get_template(
1564
+ "Dockerfile.litserve_template")
1565
+ # Change CUDA version here from available list
1566
+ # @see https://hub.docker.com/r/nvidia/cuda/tags
1567
+ dockerfile_content = template.render(
1568
+ {"cuda_version": "12.0.0"})
1569
+ with open(os.path.join(
1570
+ ctx.serving_artifacts_local_folder,
1571
+ "Dockerfile.litserve"), 'w'
1572
+ ) as output_file:
1573
+ output_file.write(dockerfile_content)
1574
+
1575
+ os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0"
1576
+
1577
+ ############################################
1578
+ # actually deploy the inference service #
1579
+ ############################################
1580
+ start_time = time.time()
1581
+ from retrain_pipelines.utils.docker import \
1582
+ build_and_run_docker, print_container_log_tail, \
1583
+ cleanup_docker
1584
+ from retrain_pipelines.model.litserve import \
1585
+ endpoint_started, endpoint_is_ready
1586
+
1587
+ ctx.port = 8765
1588
+ HF_HUB_CACHE = os.path.realpath(os.path.expanduser(
1589
+ os.getenv(
1590
+ "HF_HUB_CACHE",
1591
+ os.path.join(os.getenv("HF_HOME",
1592
+ "~/.cache/huggingface"),
1593
+ "hub")
1594
+ )))
1595
+ print(f"HF_HUB_CACHE : {HF_HUB_CACHE}")
1596
+ image_name = container_name = "litserve-model"
1597
+
1598
+ serving_container = build_and_run_docker(
1599
+ image_name=image_name, image_tag="1.0",
1600
+ build_path=ctx.serving_artifacts_local_folder,
1601
+ dockerfile="Dockerfile.litserve",
1602
+ ports_publish_dict={'8000/tcp': ctx.port},
1603
+ env_vars_dict={
1604
+ "HF_HUB_CACHE": "/huggingface_hub_cache",
1605
+ "HF_TOKEN": os.getenv("HF_TOKEN")
1606
+ },
1607
+ volumes_dict={
1608
+ ctx.sft_model_dir:
1609
+ {"bind": "/FuncCallAdapter",
1610
+ "mode": "ro"},
1611
+ HF_HUB_CACHE:
1612
+ {"bind": "/huggingface_hub_cache",
1613
+ "mode": "ro"}
1614
+ }
1615
+ )
1616
+
1617
+ if not serving_container:
1618
+ print("failed spinning the LitServe container",
1619
+ file=sys.stderr)
1620
+ ctx.local_serve_is_ready = \
1621
+ LocalServeReadinessEnum.FAILURE
1622
+ try:
1623
+ cleanup_docker(
1624
+ container_name=container_name,
1625
+ image_name=f"{image_name}:1.0",
1626
+ no_pruning=True # for intermediate layers recycling
1627
+ # (during later re-runs)
1628
+ # to avoid long rebuild time
1629
+ # of exactly the same.
1630
+ )
1631
+ except Exception as cleanup_ex:
1632
+ # fail silently
1633
+ pass
1634
+ else:
1635
+ print("Awaiting endpoint launch..")
1636
+ start_time = time.time()
1637
+ if not endpoint_started(
1638
+ container_name, port=ctx.port, timeout=10*60
1639
+ ):
1640
+ print(
1641
+ f"The endpoint '{container_name}' " +
1642
+ f"did not start.")
1643
+ ctx.local_serve_is_ready = \
1644
+ LocalServeReadinessEnum.FAILURE
1645
+ # health check on the spun-up endpoint
1646
+ elif endpoint_is_ready(port=ctx.port):
1647
+ ctx.local_serve_is_ready = \
1648
+ LocalServeReadinessEnum.SUCCESS
1649
+ elapsed_time = time.time() - start_time
1650
+ print("deploy_local - Elapsed time: " +
1651
+ f"{elapsed_time:.2f} seconds")
1652
+ ############################################
1653
+ else:
1654
+ # env doesn't have docker
1655
+ ctx.local_serve_is_ready = \
1656
+ LocalServeReadinessEnum.FAILURE_NO_DOCKER
1657
+
1658
+ if LocalServeReadinessEnum.SUCCESS == ctx.local_serve_is_ready:
1659
+ from retrain_pipelines.model.litserve.litserve_datamodel \
1660
+ import Response
1661
+
1662
+ import requests
1663
+
1664
+ url = f"http://localhost:{ctx.port}/predict"
1665
+ headers = {"accept": "application/x-www-form-urlencoded"}
1666
+
1667
+ try:
1668
+ start_time = time.time()
1669
+ data = {
1670
+ "adapter_name": "func_caller",
1671
+ "queries_list": '["Hello.", "Is 49 a perfect square?"]'
1672
+ }
1673
+ print(f"inference test - data: {data}")
1674
+ response = requests.post(url, headers=headers, data=data)
1675
+ parsed_response = Response(**{"output": response.json()})
1676
+ elapsed_time = time.time() - start_time
1677
+ print("parsed_response ('func_caller' adapter ON) :" +
1678
+ str(parsed_response) +
1679
+ f"\t-\tElapsed time: {elapsed_time:.2f} seconds")
1680
+
1681
+ start_time = time.time()
1682
+ data = {
1683
+ "queries_list": '["Hello.", "Is 49 a perfect square?"]'
1684
+ }
1685
+ print(f"inference test - data: {data}")
1686
+ response = requests.post(url, headers=headers, data=data)
1687
+ parsed_response = Response(**{"output": response.json()})
1688
+ elapsed_time = time.time() - start_time
1689
+ print(f"parsed_response (no adapter) : {parsed_response}" +
1690
+ f"\t-\tElapsed time: {elapsed_time:.2f} seconds")
1691
+
1692
+ except Exception as ex:
1693
+ print(ex, file=sys.stderr)
1694
+ traceback.print_tb(ex.__traceback__, file=sys.stderr)
1695
+ ctx.local_serve_is_ready = \
1696
+ LocalServeReadinessEnum.FAILURE
1697
+ pass
1698
+
1699
+ try:
1700
+ cleanup_docker(
1701
+ container_name=container_name,
1702
+ image_name=f"{image_name}:1.0",
1703
+ no_pruning=True # for intermediate layers recycling
1704
+ # (during later re-runs)
1705
+ # to avoid long rebuild time
1706
+ # of exactly the same.
1707
+ )
1708
+ except Exception as cleanup_ex:
1709
+ # fail silently
1710
+ pass
1711
+
1712
+ return None
1713
+
1714
+
1715
+ @task
1716
+ def pipeline_card(_, task_id: int) -> None:
1717
+ #############################
1718
+ # case of user-provided #
1719
+ # documentation artifact(s) #
1720
+ #############################
1721
+ # note that user can provide either
1722
+ # 'pipeline_card.py' or 'template.html'
1723
+ # or 'dataset_readme.py'
1724
+ # or 'dataset_readme_template.md'
1725
+ # or 'model_readme.py'
1726
+ # or 'model_readme_template.md'
1727
+ # or any combination of those
1728
+ # when specifying custom
1729
+ # 'pipeline_card_artifacts_path'
1730
+ if "template.html" in os.listdir(
1731
+ ctx.pipeline_card_artifacts_path
1732
+ ):
1733
+ template_dir = ctx.pipeline_card_artifacts_path
1734
+ else:
1735
+ template_dir = os.path.dirname(
1736
+ importlib.util.find_spec(
1737
+ f"retrain_pipelines.pipeline_card."+
1738
+ f"{os.getenv('retrain_pipeline_type')}"
1739
+ ).origin)
1740
+ #############################
1741
+ if "pipeline_card.py" in os.listdir(
1742
+ ctx.pipeline_card_artifacts_path
1743
+ ):
1744
+ from retrain_pipelines.utils import get_get_html
1745
+ get_html = \
1746
+ get_get_html(ctx.pipeline_card_artifacts_path)
1747
+ else:
1748
+ from retrain_pipelines.pipeline_card import \
1749
+ get_html
1750
+ from retrain_pipelines.dag_engine.renderer import dag_svg
1751
+ #############################
1752
+
1753
+ #############################
1754
+ ## html "custom" card ##
1755
+ #############################
1756
+ dt = datetime.now(tz=timezone.utc)
1757
+ formatted_dt = dt.strftime("%A %b %d %Y %I:%M:%S %p %Z")
1758
+ task_obj_python_cmd = f"sdk.Task({task_id})"
1759
+ executions_count = ExecutionsIterator(
1760
+ exec_name=ctx.pipeline_name).length()
1761
+
1762
+ params={
1763
+ 'template_dir': template_dir,
1764
+ 'title': ctx.pipeline_name,
1765
+ "subtitle": f"(Pipeline execution # {executions_count}," + \
1766
+ f" exec_id: {str(ctx.exec_id)} - {formatted_dt})",
1767
+
1768
+ # blessed status / current_blessed version
1769
+ 'model_version_blessed': ctx.model_version_blessed,
1770
+ 'current_blessed_version_label': (
1771
+ ctx.current_blessed_version_dict["version_label"]
1772
+ if ctx.current_blessed_version_dict
1773
+ else None
1774
+ ),
1775
+ 'current_blessed_commit_datetime': (
1776
+ ctx.current_blessed_version_dict["commit_datetime"]
1777
+ if ctx.current_blessed_version_dict
1778
+ else None
1779
+ ),
1780
+ 'current_blessed_model_commit_hash': (
1781
+ ctx.current_blessed_version_dict["commit_hash"]
1782
+ if ctx.current_blessed_version_dict
1783
+ else None
1784
+ ),
1785
+ 'current_blessed_run': ctx.current_blessed_run,
1786
+
1787
+ 'LocalServeReadinessEnum': LocalServeReadinessEnum,
1788
+ 'local_serve_is_ready': ctx.local_serve_is_ready,
1789
+ # EDA
1790
+ 'main_dataset_repo_id': ctx.hf_dataset['repo_id'],
1791
+ 'main_dataset_commit_hash': ctx.hf_dataset_dict['commit_hash'],
1792
+ 'main_dataset_commit_datetime': \
1793
+ ctx.hf_dataset_dict['commit_datetime'],
1794
+
1795
+ 'records_count': ctx.records_count,
1796
+ 'data_schema': ctx.data_schema,
1797
+ 'answers_tools_count_fig': ctx.answers_tools_count_fig,
1798
+ 'words_count_fig': ctx.words_count_fig,
1799
+
1800
+ # model training
1801
+ 'dataset_repo_id': ctx.dataset_repo_id,
1802
+ 'dataset_version_label': ctx.dataset_commit_dict["version_label"],
1803
+ 'dataset_commit_datetime': ctx.dataset_commit_dict["commit_datetime"],
1804
+ 'dataset_commit_hash': ctx.dataset_commit_dict["commit_hash"],
1805
+ 'dataset_augmentation_rate': ctx.actual_augmentation_rate,
1806
+ 'dataset_enrichment_rate': ctx.enrichment_rate,
1807
+
1808
+ # trained model version
1809
+ 'model_repo_id': ctx.model_repo_id,
1810
+ 'model_version_label': ctx.model_commit_dict["version_label"],
1811
+ 'model_commit_datetime': ctx.model_commit_dict["commit_datetime"],
1812
+ 'model_commit_hash': ctx.model_commit_dict["commit_hash"],
1813
+
1814
+ 'cpt_log_history_fig': ctx.cpt_log_history_fig,
1815
+ 'sft_log_history_fig': ctx.sft_log_history_fig,
1816
+
1817
+ 'validation_completions_fig': ctx.validation_completions_fig,
1818
+
1819
+ 'hf_base_model_dict': ctx.hf_base_model_dict,
1820
+ 'pipeline_parameters_dict': {"cpt": ctx.cpt_training_args,
1821
+ "sft": ctx.sft_training_args},
1822
+
1823
+ 'metrics_dict': ctx.perf_metrics,
1824
+
1825
+ 'task_obj_python_cmd': task_obj_python_cmd,
1826
+ 'dag_svg': dag_svg(execution_id=ctx.exec_id)
1827
+ }
1828
+ html = get_html(params)
1829
+
1830
+ filename = os.path.join(
1831
+ os.environ["RP_ARTIFACTS_STORE"],
1832
+ ctx.pipeline_name, str(ctx.exec_id),
1833
+ "pipeline_card.html"
1834
+ )
1835
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
1836
+ with open(filename, "w", encoding="utf-8") as file:
1837
+ file.write(html)
1838
+ logger.debug(
1839
+ "pipeline_card - " +
1840
+ f"[bold]pipeline_card_file_fullname : {filename}[/]")
1841
+
1842
+ ctx.pipeline_card_file_fullname = filename
1843
+ #############################
1844
+
1845
+ return None
1846
+
1847
+
1848
+ @task(ui_css=UiCss(background="#FF9900", color="#111827", border="#1F2937"))
1849
+ def pipeline_to_hub(_) -> None:
1850
+ """
1851
+ publish versioned source-code and pipeline-card
1852
+ for ths run on the Hugging Face Hub.
1853
+ """
1854
+ model_commit_datetime = \
1855
+ ctx.model_commit_dict["commit_datetime"]
1856
+ timestamp_str = \
1857
+ "{:%Y%m%d_%H%M%S}".format(model_commit_datetime) + \
1858
+ "{:03d}".format(model_commit_datetime.microsecond//1000) + \
1859
+ "_UTC"
1860
+ subfolder_name = \
1861
+ "v" + ctx.model_commit_dict["version_label"] + \
1862
+ "_" + timestamp_str
1863
+ commit_datetime = datetime.utcnow()
1864
+
1865
+ ###############################
1866
+ # source-code #
1867
+ ###############################
1868
+ # We upload only herein file #
1869
+ # plus user-provided versions #
1870
+ # of the customizable ones #
1871
+ # (if any). #
1872
+ ###############################
1873
+ custom_source_files = [os.path.abspath(__file__)]
1874
+ if (
1875
+ ctx.pipeline_card_artifacts_path != \
1876
+ ctx.params_definitions["pipeline_card_artifacts_path"].default
1877
+ ):
1878
+ candidate_source_files = [
1879
+ "pipeline_card.py",
1880
+ "template.html",
1881
+ "dataset_readme.py",
1882
+ "dataset_readme_template.md",
1883
+ "model_readme.py",
1884
+ "model_readme_template.md"
1885
+ ]
1886
+ for candidate_source_file in candidate_source_files:
1887
+ file_fullpath = os.path.join(
1888
+ ctx.pipeline_card_artifacts_path,
1889
+ candidate_source_file)
1890
+ if os.path.exists(file_fullpath):
1891
+ custom_source_files.append(file_fullpath)
1892
+
1893
+ source_code_commit_hash = \
1894
+ push_files_to_hub_repo_branch(
1895
+ repo_id=ctx.model_repo_id,
1896
+ branch_name="retrain-pipelines_source-code",
1897
+ file_fullnames=custom_source_files,
1898
+ include_requirements_txt=True,
1899
+ path_in_repo=subfolder_name,
1900
+ commit_message=\
1901
+ "source-code for model version " + \
1902
+ subfolder_name + \
1903
+ f"- retrain-pipelines {__version__}",
1904
+ repo_type="model",
1905
+ hf_token=os.getenv("HF_TOKEN", None)
1906
+ )
1907
+ print(source_code_commit_hash)
1908
+ ctx.source_code_commit_dict = {
1909
+ "repo_id": ctx.model_repo_id,
1910
+ "branch_name": "retrain-pipelines_source-code",
1911
+ "commit_datetime": commit_datetime,
1912
+ "commit_hash": source_code_commit_hash
1913
+ }
1914
+ ###############################
1915
+
1916
+ ###############################
1917
+ # pipeline-card #
1918
+ ###############################
1919
+ pipeline_card_commit_hash = \
1920
+ push_files_to_hub_repo_branch(
1921
+ repo_id=ctx.model_repo_id,
1922
+ branch_name="retrain-pipelines_pipeline-card",
1923
+ file_fullnames=[ctx.pipeline_card_file_fullname],
1924
+ path_in_repo=subfolder_name,
1925
+ commit_message=\
1926
+ "pipeline-card for model version " + \
1927
+ subfolder_name + \
1928
+ f"- retrain-pipelines {__version__}",
1929
+ repo_type="model",
1930
+ hf_token=os.getenv("HF_TOKEN", None)
1931
+ )
1932
+ print(pipeline_card_commit_hash)
1933
+ ctx.pipeline_card_commit_dict = {
1934
+ "repo_id": ctx.model_repo_id,
1935
+ "branch_name": "retrain-pipelines_pipeline-card",
1936
+ "commit_datetime": commit_datetime,
1937
+ "commit_hash": pipeline_card_commit_hash
1938
+ }
1939
+ ###############################
1940
+
1941
+ return None
1942
+
1943
+
1944
+ @task
1945
+ def deploy(_):
1946
+ """
1947
+ placeholder for the serving SDK deploy call
1948
+ (on the target production platform).
1949
+ consider including the portable pipelione-card itself !
1950
+ """
1951
+
1952
+ if ctx.model_version_blessed and (ctx.local_serve_is_ready == 1):
1953
+ pass # your code here
1954
+
1955
+ return None
1956
+
1957
+
1958
+ @task
1959
+ def load_test(_):
1960
+ """
1961
+ placeholder
1962
+ """
1963
+
1964
+ if ctx.model_version_blessed and (ctx.local_serve_is_ready == 1):
1965
+ pass # your code here
1966
+
1967
+ return None
1968
+
1969
+
1970
+ @task
1971
+ def end(_):
1972
+ pass
1973
+
1974
+
1975
+ #--- retraining-pipeline params & DAG ---------------------------------------------------
1976
+
1977
+
1978
+ @dag(ui_css=UiCss(color="#FFDD00", background="#7AD4FF", border="#C28E00"))
1979
+ def retrain_pipeline():
1980
+ """
1981
+ Retraining pipeline with SFT & CPT. Small LLM with pluggable adapter specialized in tool-calling from intrinsic knowledge bank of tools and not from extended context. Model-version blessing. Serving via a custom LitServe toy-server.
1982
+ """
1983
+ # @see https://github.com/unslothai/unsloth/wiki
1984
+
1985
+ #--- flow parameters -------------------------------------------------------
1986
+
1987
+
1988
+ RETRAIN_PIPELINE_TYPE = "mf_unsloth_func_call_litserve"
1989
+ # best way to share the config across subprocesses
1990
+ os.environ["retrain_pipeline_type"] = RETRAIN_PIPELINE_TYPE
1991
+
1992
+ hf_dataset = DagParam(
1993
+ description="dict with 'repo_id' and 'commit_hash' keys. " + \
1994
+ "if 'commit_hash is None, falls back to latest version " +\
1995
+ "of the dataset available in parquet format.\n" +
1996
+ "Note that there are 3 required 'attributes' of type " + \
1997
+ "str, list[str], list[str]",
1998
+ default=dedent("""{
1999
+ "repo_id": "Salesforce/xlam-function-calling-60k",
2000
+ "config_name": "",
2001
+ "commit_hash": "",
2002
+ "attributes": {
2003
+ "query_attr": "query",
2004
+ "answers_attr": "answers",
2005
+ "tools_attr": "tools"
2006
+ }
2007
+ }""")
2008
+ )
2009
+
2010
+ augmentation_rate = DagParam(
2011
+ description="(float) proportion of records to be augmented " + \
2012
+ "(x% of original dataset is created" + \
2013
+ " as additional augmented datapoints), i.e. " + \
2014
+ "truncated queries to serve as negative examples, " + \
2015
+ "meaning they trigger no tool call " + \
2016
+ "due to info incompleteness.",
2017
+ default=.05
2018
+ )
2019
+
2020
+ hf_enrich_dataset = DagParam(
2021
+ description="dict with 'repo_id', 'config_name' and 'commit_hash', " + \
2022
+ "query_attribute' and 'query_attribute_handler' keys. " + \
2023
+ "if 'commit_hash is None, falls back to latest version " + \
2024
+ "of the dataset available in parquet format." + \
2025
+ "'query_attribute' depicts the dataset attribute " + \
2026
+ "from which 'queries' are to be sampled." + \
2027
+ "'query_attribute_handler' serves for attributes " + \
2028
+ "that have complex structure, " + \
2029
+ "other than 'string' datatype.",
2030
+ # @see https://huggingface.co/datasets/google-research-datasets/natural_questions
2031
+ default=dedent("""{
2032
+ "repo_id": "lighteval/natural_questions_clean",
2033
+ "config_name": "",
2034
+ "commit_hash": "",
2035
+ "query_attribute": "question",
2036
+ "query_attribute_handler": "lambda x: x"
2037
+ }""")
2038
+ )
2039
+
2040
+ enrichment_rate = DagParam(
2041
+ description="(float) proportion of records " + \
2042
+ "to be added from the 'hf_enrich_dataset'" + \
2043
+ "(x% of original dataset is sampled and" + \
2044
+ " added as enriching datapoints), i.e. " + \
2045
+ "queries to serve as negative examples, " + \
2046
+ "due to their complete disconnexion " + \
2047
+ "to tool calling situations.",
2048
+ default=.1
2049
+ )
2050
+
2051
+ dataset_repo_id = DagParam(
2052
+ description="(str) The 'repo_id' to be used " + \
2053
+ "for the Hugging Face dataset version push " + \
2054
+ "(will be created at runtime" + \
2055
+ " if doesn't already exist).",
2056
+ default="retrain-pipelines/func_calls"
2057
+ )
2058
+
2059
+ polars_engine = DagParam(
2060
+ description="The engine used by Polars for " + \
2061
+ "dataset querying and processing " + \
2062
+ "(either 'gpu' or 'cpu').",
2063
+ default="gpu"
2064
+ )
2065
+
2066
+ hf_base_model = DagParam(
2067
+ description="(str) dict with 'repo_id' and 'commit_hash' keys." + \
2068
+ "if 'commit_hash is None, falls back " + \
2069
+ "to latest available version of the model.",
2070
+ default=dedent("""{
2071
+ "repo_id": "unsloth/Qwen2.5-1.5B",
2072
+ "commit_hash": ""
2073
+ }""")
2074
+ )
2075
+
2076
+ cpt_training_args = DagParam(
2077
+ description="dict with `TrainingArguments` params " + \
2078
+ "for the CPT job.",
2079
+ default=dedent("""{
2080
+ "warmup_ratio": 0.1,
2081
+ "num_train_epochs": 1
2082
+ }""")
2083
+ )
2084
+
2085
+ sft_training_args = DagParam(
2086
+ description="dict with `TrainingArguments` params " + \
2087
+ "for the SFT job.",
2088
+ default=dedent("""{
2089
+ "warmup_ratio": 0.1,
2090
+ "num_train_epochs": 1
2091
+ }""")
2092
+ )
2093
+
2094
+ model_repo_id = DagParam(
2095
+ description="(str) The 'repo_id' to be used " + \
2096
+ "for the Hugging Face model version push " + \
2097
+ "(will be created at runtime" + \
2098
+ " if doesn't already exist).",
2099
+ default="retrain-pipelines/function_caller"
2100
+ )
2101
+
2102
+ default_pipeline_card_module_dir = \
2103
+ os.path.dirname(
2104
+ importlib.util.find_spec(
2105
+ f"retrain_pipelines.pipeline_card."+
2106
+ f"{RETRAIN_PIPELINE_TYPE}"
2107
+ ).origin)
2108
+ pipeline_card_artifacts_path = DagParam(
2109
+ description="pipeline_card artifacts location " + \
2110
+ "(i.e. dir hosting your optional " + \
2111
+ " custom documentation files :" + \
2112
+ " 'pipeline_card.py' and/or 'template.html'" + \
2113
+ " and/or 'model_readme.py'"+\
2114
+ " and/or 'model_readme_template.md'," + \
2115
+ " and/or 'dataset_readme.py'" + \
2116
+ " and/or 'dataset_readme_template.md' file), " + \
2117
+ "if different from default.",
2118
+ default=default_pipeline_card_module_dir
2119
+ )
2120
+ # TODO - convert from class method to TBD
2121
+ # @staticmethod
2122
+ # def copy_default_dataset_readme_module(
2123
+ # target_dir: str,
2124
+ # exists_ok: bool = False
2125
+ # ) -> None:
2126
+ # os.makedirs(target_dir, exist_ok=True)
2127
+ # if (
2128
+ # not exists_ok and
2129
+ # os.path.exists(os.path.join(target_dir, "dataset_readme.py"))
2130
+ # ):
2131
+ # print("File already exists. Skipping copy.")
2132
+ # else:
2133
+ # filefullname = os.path.join(
2134
+ # default_pipeline_card_module_dir,
2135
+ # "dataset_readme.py"
2136
+ # )
2137
+ # shutil.copy(filefullname, target_dir)
2138
+ # print(filefullname)
2139
+ # TODO - convert from class method to TBD
2140
+ # @staticmethod
2141
+ # def copy_default_dataset_readme_template(
2142
+ # target_dir: str,
2143
+ # exists_ok: bool = False
2144
+ # ) -> None:
2145
+ # os.makedirs(target_dir, exist_ok=True)
2146
+ # if (
2147
+ # not exists_ok and
2148
+ # os.path.exists(os.path.join(target_dir,
2149
+ # "dataset_readme_template.md"))
2150
+ # ):
2151
+ # print("File already exists. Skipping copy.")
2152
+ # else:
2153
+ # filefullname = os.path.join(
2154
+ # default_pipeline_card_module_dir,
2155
+ # "dataset_readme_template.md")
2156
+ # shutil.copy(filefullname, target_dir)
2157
+ # print(filefullname)
2158
+ # TODO - convert from class method to TBD
2159
+ # @staticmethod
2160
+ # def copy_default_model_readme_module(
2161
+ # target_dir: str,
2162
+ # exists_ok: bool = False
2163
+ # ) -> None:
2164
+ # os.makedirs(target_dir, exist_ok=True)
2165
+ # if (
2166
+ # not exists_ok and
2167
+ # os.path.exists(os.path.join(target_dir, "model_readme.py"))
2168
+ # ):
2169
+ # print("File already exists. Skipping copy.")
2170
+ # else:
2171
+ # filefullname = os.path.join(
2172
+ # default_pipeline_card_module_dir,
2173
+ # "model_readme.py"
2174
+ # )
2175
+ # shutil.copy(filefullname, target_dir)
2176
+ # print(filefullname)
2177
+ # TODO - convert from class method to TBD
2178
+ # @staticmethod
2179
+ # def copy_default_model_readme_template(
2180
+ # target_dir: str,
2181
+ # exists_ok: bool = False
2182
+ # ) -> None:
2183
+ # os.makedirs(target_dir, exist_ok=True)
2184
+ # if (
2185
+ # not exists_ok and
2186
+ # os.path.exists(os.path.join(target_dir,
2187
+ # "model_readme_template.md"))
2188
+ # ):
2189
+ # print("File already exists. Skipping copy.")
2190
+ # else:
2191
+ # filefullname = os.path.join(
2192
+ # default_pipeline_card_module_dir,
2193
+ # "model_readme_template.md")
2194
+ # shutil.copy(filefullname, target_dir)
2195
+ # print(filefullname)
2196
+ # TODO - convert from class method to TBD
2197
+ # @staticmethod
2198
+ # def copy_default_pipeline_card_module(
2199
+ # target_dir: str,
2200
+ # exists_ok: bool = False
2201
+ # ) -> None:
2202
+ # os.makedirs(target_dir, exist_ok=True)
2203
+ # if (
2204
+ # not exists_ok and
2205
+ # os.path.exists(os.path.join(target_dir, "pipeline_card.py"))
2206
+ # ):
2207
+ # print("File already exists. Skipping copy.")
2208
+ # else:
2209
+ # filefullname = os.path.join(
2210
+ # default_pipeline_card_module_dir,
2211
+ # "pipeline_card.py"
2212
+ # )
2213
+ # shutil.copy(filefullname, target_dir)
2214
+ # print(filefullname)
2215
+ # TODO - convert from class method to TBD
2216
+ # @staticmethod
2217
+ # def copy_default_pipeline_card_html_template(
2218
+ # target_dir: str,
2219
+ # exists_ok: bool = False
2220
+ # ) -> None:
2221
+ # os.makedirs(target_dir, exist_ok=True)
2222
+ # if (
2223
+ # not exists_ok and
2224
+ # os.path.exists(os.path.join(target_dir, "template.html"))
2225
+ # ):
2226
+ # print("File already exists. Skipping copy.")
2227
+ # else:
2228
+ # filefullname = os.path.join(
2229
+ # default_pipeline_card_module_dir,
2230
+ # "template.html")
2231
+ # shutil.copy(filefullname, target_dir)
2232
+ # print(filefullname)
2233
+
2234
+ del RETRAIN_PIPELINE_TYPE
2235
+
2236
+ #---------------------------------------------------------------------------
2237
+
2238
+ return start >> eda \
2239
+ >> augment_data >> enrich_data >> dataset_to_hub \
2240
+ >> continued_pre_training >> supervised_finetuning \
2241
+ >> evaluate_model >> model_version_blessing \
2242
+ >> model_to_hub >> infra_validator >> pipeline_card \
2243
+ >> pipeline_to_hub >> deploy >> load_test >> end
2244
+