Aurelien-Morgan-Bot commited on
Commit
0f30bbe
·
verified ·
1 Parent(s): 495f634

source-code for model version v0.37_20260227_192656740_UTC - retrain-pipelines 0.1.2

Browse files
v0.37_20260227_192656740_UTC/requirements.txt ADDED
@@ -0,0 +1,739 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.4.0
2
+ accelerate==1.1.1
3
+ access==1.1.10.post3
4
+ affine==2.4.0
5
+ aiofiles==24.1.0
6
+ aiohappyeyeballs==2.6.1
7
+ aiohttp==3.13.3
8
+ aiosignal==1.4.0
9
+ aiosqlite==0.22.1
10
+ alabaster==1.0.0
11
+ albucore==0.0.24
12
+ albumentations==2.0.8
13
+ ale-py==0.11.2
14
+ alembic==1.18.4
15
+ altair==5.5.0
16
+ annotated-doc==0.0.4
17
+ annotated-types==0.7.0
18
+ antlr4-python3-runtime==4.9.3
19
+ anyio==4.12.1
20
+ anywidget==0.9.21
21
+ apsw==3.51.2.0
22
+ apswutils==0.1.2
23
+ argon2-cffi==25.1.0
24
+ argon2-cffi-bindings==25.1.0
25
+ array_record==0.8.3
26
+ arrow==1.4.0
27
+ arviz==0.22.0
28
+ astropy==7.2.0
29
+ astropy-iers-data==0.2026.2.23.0.48.33
30
+ asttokens==3.0.1
31
+ astunparse==1.6.3
32
+ atpublic==5.1
33
+ attrs==25.4.0
34
+ audioread==3.1.0
35
+ Authlib==1.6.8
36
+ autograd==1.8.0
37
+ babel==2.18.0
38
+ backcall==0.2.0
39
+ beartype==0.22.9
40
+ beautifulsoup4==4.13.5
41
+ betterproto==2.0.0b6
42
+ bigframes==2.35.0
43
+ bigquery-magics==0.10.3
44
+ bitsandbytes==0.44.1
45
+ bleach==6.3.0
46
+ blinker==1.9.0
47
+ blis==1.3.3
48
+ blobfile==3.2.0
49
+ blosc2==4.0.0
50
+ bokeh==3.8.2
51
+ boto3==1.42.58
52
+ botocore==1.42.58
53
+ Bottleneck==1.4.2
54
+ bqplot==0.12.45
55
+ branca==0.8.2
56
+ brotli==1.2.0
57
+ CacheControl==0.14.4
58
+ cachetools==6.2.6
59
+ catalogue==2.0.10
60
+ certifi==2026.1.4
61
+ cffi==2.0.0
62
+ chardet==5.2.0
63
+ charset-normalizer==3.4.4
64
+ clarabel==0.11.1
65
+ click==8.3.1
66
+ click-plugins==1.1.1.2
67
+ cligj==0.7.2
68
+ cloudpathlib==0.23.0
69
+ cloudpickle==3.1.2
70
+ cmake==3.31.10
71
+ cmdstanpy==1.3.0
72
+ colorcet==3.1.0
73
+ colorlover==0.3.0
74
+ colour==0.1.5
75
+ comm==0.2.3
76
+ community==1.0.0b1
77
+ confection==0.1.5
78
+ cons==0.4.7
79
+ contourpy==1.3.3
80
+ cramjam==2.11.0
81
+ cryptography==43.0.3
82
+ cucim-cu12 @ https://pypi.nvidia.com/cucim-cu12/cucim_cu12-26.2.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
83
+ cuda-bindings==12.9.4
84
+ cuda-core==0.3.2
85
+ cuda-pathfinder==1.3.5
86
+ cuda-python==12.9.4
87
+ cuda-toolkit==12.8.1
88
+ cudf-cu12==26.2.1
89
+ cudf-polars-cu12==26.2.1
90
+ cufflinks==0.17.3
91
+ cuml-cu12==26.2.0
92
+ cupy-cuda12x==14.0.1
93
+ curl_cffi==0.14.0
94
+ cuvs-cu12 @ https://pypi.nvidia.com/cuvs-cu12/cuvs_cu12-26.2.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
95
+ cvxopt==1.3.2
96
+ cvxpy==1.6.7
97
+ cycler==0.12.1
98
+ cyipopt==1.5.0
99
+ cymem==2.0.13
100
+ Cython==3.0.12
101
+ dask==2026.1.1
102
+ dask-cuda==26.2.0
103
+ dask-cudf-cu12==26.2.1
104
+ dataproc-spark-connect==1.0.2
105
+ datasets==4.0.0
106
+ db-dtypes==1.5.0
107
+ dbus-python==1.2.18
108
+ debugpy==1.8.15
109
+ decorator==5.2.1
110
+ defusedxml==0.7.1
111
+ deprecation==2.1.0
112
+ diffusers==0.36.0
113
+ dill==0.3.8
114
+ distributed==2026.1.1
115
+ distributed-ucxx-cu12==0.48.0
116
+ distro==1.9.0
117
+ dlib==19.24.6
118
+ dm-tree==0.1.9
119
+ docstring_parser==0.17.0
120
+ docutils==0.21.2
121
+ dopamine_rl==4.1.2
122
+ duckdb==1.3.2
123
+ earthengine-api==1.5.24
124
+ easydict==1.13
125
+ editdistance==0.8.1
126
+ eerepr==0.1.2
127
+ einops==0.8.2
128
+ en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl#sha256=1932429db727d4bff3deed6b34cfc05df17794f4a52eeb26cf8928f7c1a0fb85
129
+ entrypoints==0.4
130
+ esda==2.8.1
131
+ et_xmlfile==2.0.0
132
+ etils==1.13.0
133
+ etuples==0.3.10
134
+ executing==2.2.1
135
+ Farama-Notifications==0.0.4
136
+ fastai==2.8.7
137
+ fastapi==0.133.0
138
+ fastcore==1.12.16
139
+ fastdownload==0.0.7
140
+ fastjsonschema==2.21.2
141
+ fastlite==0.2.4
142
+ fastprogress==1.1.5
143
+ fasttransform==0.0.2
144
+ ffmpy==1.0.0
145
+ filelock==3.24.3
146
+ fiona==1.10.1
147
+ firebase-admin==6.9.0
148
+ Flask==3.1.3
149
+ flatbuffers==25.12.19
150
+ flax==0.11.2
151
+ folium==0.20.0
152
+ fonttools==4.61.1
153
+ fqdn==1.5.1
154
+ frozendict==2.4.7
155
+ frozenlist==1.8.0
156
+ fsspec==2025.3.0
157
+ future==1.0.0
158
+ gast==0.7.0
159
+ gcsfs==2025.3.0
160
+ GDAL==3.8.4
161
+ gdown==5.2.1
162
+ geemap==0.35.3
163
+ geocoder==1.38.1
164
+ geographiclib==2.1
165
+ geopandas==1.1.2
166
+ geopy==2.4.1
167
+ giddy==2.3.8
168
+ gin-config==0.5.0
169
+ gitdb==4.0.12
170
+ GitPython==3.1.46
171
+ glob2==0.7
172
+ google==3.0.0
173
+ google-adk==1.25.1
174
+ google-ai-generativelanguage==0.6.15
175
+ google-api-core==2.30.0
176
+ google-api-python-client==2.190.0
177
+ google-auth==2.47.0
178
+ google-auth-httplib2==0.3.0
179
+ google-auth-oauthlib==1.2.4
180
+ google-cloud-aiplatform==1.138.0
181
+ google-cloud-appengine-logging==1.8.0
182
+ google-cloud-audit-log==0.4.0
183
+ google-cloud-bigquery==3.40.1
184
+ google-cloud-bigquery-connection==1.20.0
185
+ google-cloud-bigquery-storage==2.36.2
186
+ google-cloud-bigtable==2.35.0
187
+ google-cloud-core==2.5.0
188
+ google-cloud-dataproc==5.25.0
189
+ google-cloud-datastore==2.23.0
190
+ google-cloud-discoveryengine==0.13.12
191
+ google-cloud-firestore==2.23.0
192
+ google-cloud-functions==1.22.0
193
+ google-cloud-iam==2.21.0
194
+ google-cloud-language==2.19.0
195
+ google-cloud-logging==3.13.0
196
+ google-cloud-monitoring==2.29.1
197
+ google-cloud-pubsub==2.35.0
198
+ google-cloud-resource-manager==1.16.0
199
+ google-cloud-secret-manager==2.26.0
200
+ google-cloud-spanner==3.63.0
201
+ google-cloud-speech==2.36.1
202
+ google-cloud-storage==3.9.0
203
+ google-cloud-trace==1.18.0
204
+ google-cloud-translate==3.24.0
205
+ google-colab @ file:///colabtools/dist/google_colab-1.0.0.tar.gz
206
+ google-crc32c==1.8.0
207
+ google-genai==1.64.0
208
+ google-generativeai==0.8.6
209
+ google-pasta==0.2.0
210
+ google-resumable-media==2.8.0
211
+ googleapis-common-protos==1.72.0
212
+ googledrivedownloader==1.1.0
213
+ gradio==5.50.0
214
+ gradio_client==1.14.0
215
+ grain==0.2.15
216
+ graphviz==0.21
217
+ greenlet==3.3.2
218
+ groovy==0.1.2
219
+ grpc-google-iam-v1==0.14.3
220
+ grpc-interceptor==0.15.4
221
+ grpcio==1.67.1
222
+ grpcio-health-checking==1.67.1
223
+ grpcio-status==1.71.2
224
+ grpclib==0.4.9
225
+ gspread==6.2.1
226
+ gspread-dataframe==4.0.0
227
+ gym==0.25.2
228
+ gym-notices==0.1.0
229
+ gymnasium==1.2.3
230
+ h11==0.16.0
231
+ h2==4.3.0
232
+ h5netcdf==1.8.1
233
+ h5py==3.15.1
234
+ hdbscan==0.8.41
235
+ hf-xet==1.3.0
236
+ hf_transfer==0.1.9
237
+ highspy==1.13.1
238
+ holidays==0.91
239
+ holoviews==1.22.1
240
+ hpack==4.1.0
241
+ html5lib==1.1
242
+ httpcore==1.0.9
243
+ httpimport==1.4.1
244
+ httplib2==0.31.2
245
+ httptools==0.7.1
246
+ httpx==0.28.1
247
+ httpx-sse==0.4.3
248
+ huggingface-hub==0.27.1
249
+ humanize==4.15.0
250
+ hyperframe==6.1.0
251
+ hyperopt==0.2.7
252
+ ibis-framework==9.5.0
253
+ idna==3.11
254
+ ImageIO==2.37.2
255
+ imageio-ffmpeg==0.6.0
256
+ imagesize==1.4.1
257
+ imbalanced-learn==0.14.1
258
+ immutabledict==4.3.1
259
+ importlib_metadata==8.7.1
260
+ importlib_resources==6.5.2
261
+ imutils==0.5.4
262
+ inequality==1.1.2
263
+ inflect==7.5.0
264
+ iniconfig==2.3.0
265
+ intel-cmplr-lib-ur==2025.3.2
266
+ intel-openmp==2025.3.2
267
+ ipyevents==2.0.4
268
+ ipyfilechooser==0.6.0
269
+ ipykernel==7.2.0
270
+ ipyleaflet==0.20.0
271
+ ipyparallel==8.8.0
272
+ ipython==8.21.0
273
+ ipython-genutils==0.2.0
274
+ ipython-sql==0.5.0
275
+ ipytree==0.2.2
276
+ ipywidgets==7.7.1
277
+ isoduration==20.11.0
278
+ itsdangerous==2.2.0
279
+ jaraco.classes==3.4.0
280
+ jaraco.context==6.1.0
281
+ jaraco.functools==4.4.0
282
+ jax==0.7.2
283
+ jax-cuda12-pjrt==0.7.2
284
+ jax-cuda12-plugin==0.7.2
285
+ jaxlib==0.7.2
286
+ jedi==0.19.2
287
+ jeepney==0.9.0
288
+ jieba==0.42.1
289
+ Jinja2==3.1.6
290
+ jiter==0.13.0
291
+ jmespath==1.1.0
292
+ joblib==1.5.3
293
+ jsonpatch==1.33
294
+ jsonpickle==4.1.1
295
+ jsonpointer==3.0.0
296
+ jsonschema==4.26.0
297
+ jsonschema-specifications==2025.9.1
298
+ jupyter-console==6.6.3
299
+ jupyter-events==0.12.0
300
+ jupyter-leaflet==0.20.0
301
+ jupyter_client==8.8.0
302
+ jupyter_core==5.9.1
303
+ jupyter_kernel_gateway @ git+https://github.com/googlecolab/kernel_gateway@b134e9945df25c2dcb98ade9129399be10788671
304
+ jupyter_server==2.14.0
305
+ jupyter_server_terminals==0.5.4
306
+ jupyterlab_pygments==0.3.0
307
+ jupyterlab_widgets==3.0.16
308
+ jupytext==1.19.1
309
+ kaggle==1.7.4.5
310
+ kagglehub==0.3.13
311
+ keras==3.10.0
312
+ keras-hub==0.21.1
313
+ keras-nlp==0.21.1
314
+ keyring==25.7.0
315
+ keyrings.google-artifactregistry-auth==1.1.2
316
+ kiwisolver==1.4.9
317
+ langchain==1.2.10
318
+ langchain-core==1.2.15
319
+ langgraph==1.0.9
320
+ langgraph-checkpoint==4.0.0
321
+ langgraph-prebuilt==1.0.8
322
+ langgraph-sdk==0.3.9
323
+ langsmith==0.7.6
324
+ lark==1.3.1
325
+ lazy_loader==0.4
326
+ libclang==18.1.1
327
+ libcudf-cu12==26.2.1
328
+ libcugraph-cu12==26.2.0
329
+ libcuml-cu12==26.2.0
330
+ libcuvs-cu12==26.2.0
331
+ libkvikio-cu12==26.2.0
332
+ libpysal==4.14.1
333
+ libraft-cu12==26.2.0
334
+ librmm-cu12==26.2.0
335
+ librosa==0.11.0
336
+ libucx-cu12==1.19.0
337
+ libucxx-cu12==0.48.0
338
+ lightgbm==4.6.0
339
+ linkify-it-py==2.0.3
340
+ llvmlite==0.43.0
341
+ locket==1.0.0
342
+ logical-unification==0.4.7
343
+ lxml==6.0.2
344
+ Mako==1.3.10
345
+ mapclassify==2.10.0
346
+ Markdown==3.10.2
347
+ markdown-it-py==4.0.0
348
+ MarkupSafe==3.0.3
349
+ matplotlib==3.10.0
350
+ matplotlib-inline==0.2.1
351
+ matplotlib-venn==1.1.2
352
+ mcp==1.26.0
353
+ mdit-py-plugins==0.5.0
354
+ mdurl==0.1.2
355
+ metaflow==2.19.20
356
+ mgwr==2.2.1
357
+ miniKanren==1.0.5
358
+ missingno==0.5.2
359
+ mistune==3.2.0
360
+ mizani==0.13.5
361
+ mkl==2025.3.1
362
+ ml_dtypes==0.5.4
363
+ mlxtend==0.23.4
364
+ mmh3==5.2.0
365
+ momepy==0.11.0
366
+ more-itertools==10.8.0
367
+ moviepy==1.0.3
368
+ mpmath==1.3.0
369
+ msgpack==1.1.2
370
+ multidict==6.7.1
371
+ multipledispatch==1.0.0
372
+ multiprocess==0.70.16
373
+ multitasking==0.0.12
374
+ murmurhash==1.0.15
375
+ music21==9.9.1
376
+ namex==0.1.0
377
+ narwhals==2.17.0
378
+ natsort==8.4.0
379
+ nbclassic==1.3.3
380
+ nbclient==0.10.4
381
+ nbconvert==7.17.0
382
+ nbformat==5.10.4
383
+ ndindex==1.10.1
384
+ nest-asyncio==1.6.0
385
+ networkx==3.6.1
386
+ nibabel==5.3.3
387
+ nltk==3.9.1
388
+ notebook==6.5.7
389
+ notebook_shim==0.2.4
390
+ numba==0.60.0
391
+ numba-cuda==0.22.2
392
+ numexpr==2.14.1
393
+ numpy==2.4.2
394
+ nvidia-cublas-cu12==12.1.3.1
395
+ nvidia-cuda-cccl-cu12==12.9.27
396
+ nvidia-cuda-cupti-cu12==12.1.105
397
+ nvidia-cuda-nvcc-cu12==12.5.82
398
+ nvidia-cuda-nvrtc-cu12==12.1.105
399
+ nvidia-cuda-runtime-cu12==12.1.105
400
+ nvidia-cudnn-cu12==9.1.0.70
401
+ nvidia-cufft-cu12==11.0.2.54
402
+ nvidia-curand-cu12==10.3.2.106
403
+ nvidia-cusolver-cu12==11.4.5.107
404
+ nvidia-cusparse-cu12==12.1.0.106
405
+ nvidia-libnvcomp-cu12==5.1.0.21
406
+ nvidia-ml-py==13.590.48
407
+ nvidia-nccl-cu12==2.21.5
408
+ nvidia-nvimgcodec-cu12==0.7.0.11
409
+ nvidia-nvjitlink-cu12==12.9.86
410
+ nvidia-nvtx-cu12==12.1.105
411
+ nvtx==0.2.14
412
+ nx-cugraph-cu12 @ https://pypi.nvidia.com/nx-cugraph-cu12/nx_cugraph_cu12-26.2.0-py3-none-any.whl
413
+ oauth2client==4.1.3
414
+ oauthlib==3.3.1
415
+ omegaconf==2.3.0
416
+ onemkl-license==2025.3.1
417
+ openai==2.23.0
418
+ opencv-contrib-python==4.13.0.92
419
+ opencv-python==4.13.0.92
420
+ opencv-python-headless==4.13.0.92
421
+ openpyxl==3.1.5
422
+ opentelemetry-api==1.38.0
423
+ opentelemetry-exporter-gcp-logging==1.11.0a0
424
+ opentelemetry-exporter-gcp-monitoring==1.11.0a0
425
+ opentelemetry-exporter-gcp-trace==1.11.0
426
+ opentelemetry-exporter-otlp-proto-common==1.38.0
427
+ opentelemetry-exporter-otlp-proto-http==1.38.0
428
+ opentelemetry-proto==1.38.0
429
+ opentelemetry-resourcedetector-gcp==1.11.0a0
430
+ opentelemetry-sdk==1.38.0
431
+ opentelemetry-semantic-conventions==0.59b0
432
+ opt_einsum==3.4.0
433
+ optax==0.2.7
434
+ optree==0.19.0
435
+ orbax-checkpoint==0.11.33
436
+ orjson==3.11.7
437
+ ormsgpack==1.12.2
438
+ osqp==1.1.1
439
+ overrides==7.7.0
440
+ packaging==26.0
441
+ pandas==2.2.2
442
+ pandas-datareader==0.10.0
443
+ pandas-gbq==0.30.0
444
+ pandas-stubs==2.2.2.240909
445
+ pandocfilters==1.5.1
446
+ panel==1.8.7
447
+ param==2.3.2
448
+ parso==0.8.6
449
+ parsy==2.2
450
+ partd==1.4.2
451
+ patsy==1.0.2
452
+ peewee==4.0.0
453
+ peft==0.14.0
454
+ pexpect==4.9.0
455
+ pickleshare==0.7.5
456
+ pillow==11.3.0
457
+ pip3-autoremove==2.0.1
458
+ platformdirs==4.9.2
459
+ plotly==5.24.1
460
+ plotnine==0.14.5
461
+ pluggy==1.6.0
462
+ plum-dispatch==2.7.1
463
+ pointpats==2.5.5
464
+ polars==1.35.2
465
+ polars-runtime-32==1.35.2
466
+ pooch==1.9.0
467
+ portpicker==1.5.2
468
+ preshed==3.0.12
469
+ prettytable==3.17.0
470
+ proglog==0.1.12
471
+ progressbar2==4.5.0
472
+ prometheus_client==0.24.1
473
+ promise==2.3
474
+ prompt_toolkit==3.0.52
475
+ propcache==0.4.1
476
+ prophet==1.3.0
477
+ proto-plus==1.27.1
478
+ protobuf==5.29.6
479
+ psutil==5.9.5
480
+ psycopg2==2.9.11
481
+ psygnal==0.15.1
482
+ ptyprocess==0.7.0
483
+ PuLP==3.3.0
484
+ pure_eval==0.2.3
485
+ py-cpuinfo==9.0.0
486
+ py4j==0.10.9.9
487
+ pyarrow==18.1.0
488
+ pyasn1==0.6.2
489
+ pyasn1_modules==0.4.2
490
+ pycairo==1.29.0
491
+ pycocotools==2.0.11
492
+ pycparser==3.0
493
+ pycryptodomex==3.23.0
494
+ pydantic==2.12.3
495
+ pydantic-settings==2.13.1
496
+ pydantic_core==2.41.4
497
+ pydata-google-auth==1.9.1
498
+ pydot==4.0.1
499
+ pydotplus==2.0.2
500
+ PyDrive2==1.21.3
501
+ pydub==0.25.1
502
+ pyerfa==2.0.1.5
503
+ pygame==2.6.1
504
+ pygit2==1.19.1
505
+ Pygments==2.19.2
506
+ PyGObject==3.48.2
507
+ PyJWT==2.11.0
508
+ pylibcudf-cu12==26.2.1
509
+ pylibcugraph-cu12==26.2.0
510
+ pylibraft-cu12==26.2.0
511
+ pymc==5.28.0
512
+ pynndescent==0.6.0
513
+ pyogrio==0.12.1
514
+ pyomo==6.10.0
515
+ PyOpenGL==3.1.10
516
+ pyOpenSSL==24.2.1
517
+ pyparsing==3.3.2
518
+ pyperclip==1.11.0
519
+ pyproj==3.7.2
520
+ pysal==25.7
521
+ pyshp==3.0.3
522
+ PySocks==1.7.1
523
+ pyspark==4.0.2
524
+ pytensor==2.38.0
525
+ pytest==8.4.2
526
+ python-apt==0.0.0
527
+ python-box==7.4.1
528
+ python-dateutil==2.9.0.post0
529
+ python-dotenv==1.2.1
530
+ python-fasthtml==0.12.47
531
+ python-json-logger==4.0.0
532
+ python-louvain==0.16
533
+ python-multipart==0.0.22
534
+ python-slugify==8.0.4
535
+ python-snappy==0.7.3
536
+ python-utils==3.9.1
537
+ pytz==2025.2
538
+ pyviz_comms==3.0.6
539
+ PyWavelets==1.9.0
540
+ PyYAML==6.0.3
541
+ pyzmq==26.2.1
542
+ quantecon==0.11.0
543
+ raft-dask-cu12==26.2.0
544
+ rapids-dask-dependency==26.2.0
545
+ rapids-logger==0.2.3
546
+ rasterio==1.5.0
547
+ rasterstats==0.20.0
548
+ ratelim==0.1.6
549
+ referencing==0.37.0
550
+ regex==2025.11.3
551
+ requests==2.32.4
552
+ requests-oauthlib==2.0.0
553
+ requests-toolbelt==1.0.0
554
+ requirements-parser==0.9.0
555
+ # Editable install with no version control (retrain-pipelines==0.0.0)
556
+ -e /content/pkg_src
557
+ rfc3339-validator==0.1.4
558
+ rfc3986-validator==0.1.1
559
+ rfc3987-syntax==1.1.0
560
+ rich==14.3.3
561
+ rmm-cu12==26.2.0
562
+ roman-numerals==4.1.0
563
+ roman-numerals-py==4.1.0
564
+ rpds-py==0.30.0
565
+ rpy2==3.5.17
566
+ rsa==4.9.1
567
+ rtree==1.4.1
568
+ ruff==0.15.2
569
+ s3transfer==0.16.0
570
+ safehttpx==0.1.7
571
+ safetensors==0.7.0
572
+ scikit-image==0.25.2
573
+ scikit-learn==1.6.1
574
+ scipy==1.16.3
575
+ scooby==0.11.0
576
+ scs==3.2.11
577
+ seaborn==0.13.2
578
+ SecretStorage==3.5.0
579
+ segregation==2.5.3
580
+ semantic-version==2.10.0
581
+ Send2Trash==2.1.0
582
+ sentence-transformers==5.2.3
583
+ sentencepiece==0.2.1
584
+ sentry-sdk==2.53.0
585
+ setuptools==80.10.2
586
+ shap==0.50.0
587
+ shapely==2.1.2
588
+ shellingham==1.5.4
589
+ simple-parsing==0.1.8
590
+ simplejson==3.20.2
591
+ simsimd==6.5.13
592
+ six==1.17.0
593
+ sklearn-compat==0.1.5
594
+ sklearn-pandas==2.2.0
595
+ slicer==0.0.8
596
+ smart_open==7.5.1
597
+ smmap==5.0.2
598
+ sniffio==1.3.1
599
+ snowballstemmer==3.0.1
600
+ sortedcontainers==2.4.0
601
+ soundfile==0.13.1
602
+ soupsieve==2.8.3
603
+ soxr==1.0.0
604
+ spacy==3.8.11
605
+ spacy-legacy==3.0.12
606
+ spacy-loggers==1.0.5
607
+ spaghetti==1.7.6
608
+ spanner-graph-notebook==1.1.8
609
+ spglm==1.1.0
610
+ Sphinx==8.2.3
611
+ sphinxcontrib-applehelp==2.0.0
612
+ sphinxcontrib-devhelp==2.0.0
613
+ sphinxcontrib-htmlhelp==2.1.0
614
+ sphinxcontrib-jsmath==1.0.1
615
+ sphinxcontrib-qthelp==2.0.0
616
+ sphinxcontrib-serializinghtml==2.0.0
617
+ spint==1.0.7
618
+ splot==1.1.7
619
+ spopt==0.7.0
620
+ spreg==1.8.5
621
+ SQLAlchemy==2.0.47
622
+ sqlalchemy-spanner==1.17.2
623
+ sqlglot==25.20.2
624
+ sqlite-web==0.7.1
625
+ sqlparse==0.5.5
626
+ srsly==2.5.2
627
+ sse-starlette==3.2.0
628
+ stack-data==0.6.3
629
+ stanio==0.5.1
630
+ starlette==0.52.1
631
+ statsmodels==0.14.6
632
+ stringzilla==4.6.0
633
+ stumpy==1.13.0
634
+ sympy==1.13.1
635
+ tables==3.10.2
636
+ tabulate==0.9.0
637
+ tbb==2022.3.1
638
+ tblib==3.2.2
639
+ tcmlib==1.4.1
640
+ tenacity==9.1.4
641
+ tensorboard==2.19.0
642
+ tensorboard-data-server==0.7.2
643
+ tensorflow==2.19.0
644
+ tensorflow-datasets==4.9.9
645
+ tensorflow-hub==0.16.1
646
+ tensorflow-metadata==1.17.3
647
+ tensorflow-probability==0.25.0
648
+ tensorflow-text==2.19.0
649
+ tensorflow_decision_forests==1.12.0
650
+ tensorstore==0.1.81
651
+ termcolor==3.3.0
652
+ terminado==0.18.1
653
+ text-unidecode==1.3
654
+ textblob==0.19.0
655
+ tf-slim==1.1.0
656
+ tf_keras==2.19.0
657
+ thinc==8.3.10
658
+ threadpoolctl==3.6.0
659
+ tifffile==2026.2.20
660
+ tiktoken==0.12.0
661
+ timm==1.0.25
662
+ tinycss2==1.4.0
663
+ tobler==0.13.0
664
+ tokenizers==0.20.3
665
+ toml==0.10.2
666
+ tomlkit==0.13.3
667
+ toolz==0.12.1
668
+ torch==2.5.1+cu121
669
+ torchaudio==2.5.1+cu121
670
+ torchcodec==0.10.0+cu128
671
+ torchdata==0.11.0
672
+ torchsummary==1.5.1
673
+ torchtune==0.6.1
674
+ torchvision==0.20.1+cu121
675
+ tornado==6.5.1
676
+ tqdm==4.67.3
677
+ traitlets==5.14.3
678
+ traittypes==0.2.3
679
+ transformers==4.46.2
680
+ treelite==4.6.1
681
+ treescope==0.1.10
682
+ triton==3.1.0
683
+ trl==0.12.0
684
+ tsfresh==0.21.1
685
+ tweepy==4.16.0
686
+ typeguard==4.5.1
687
+ typer==0.24.1
688
+ typer-slim==0.24.0
689
+ types-pytz==2025.2.0.20251108
690
+ types-setuptools==80.10.0.20260124
691
+ typing-inspection==0.4.2
692
+ typing_extensions==4.15.0
693
+ tyro==1.0.8
694
+ tzdata==2025.3
695
+ tzlocal==5.3.1
696
+ uc-micro-py==1.0.3
697
+ ucxx-cu12==0.48.0
698
+ umap-learn==0.5.11
699
+ umf==1.0.3
700
+ unsloth @ git+https://github.com/unslothai/unsloth.git@0c8c5ed81e423658ab9ae81eac5aab8d18f5d7af
701
+ unsloth_zoo==2024.11.5
702
+ uri-template==1.3.0
703
+ uritemplate==4.2.0
704
+ urllib3==2.5.0
705
+ uuid_utils==0.14.1
706
+ uvicorn==0.41.0
707
+ uvloop==0.22.1
708
+ vega-datasets==0.9.0
709
+ wandb==0.25.0
710
+ wasabi==1.1.3
711
+ watchdog==6.0.0
712
+ watchfiles==1.1.1
713
+ wcwidth==0.6.0
714
+ weasel==0.4.3
715
+ webcolors==25.10.0
716
+ webencodings==0.5.1
717
+ websocket-client==1.9.0
718
+ websockets==15.0.1
719
+ Werkzeug==3.1.6
720
+ wheel==0.46.3
721
+ widgetsnbextension==3.6.10
722
+ wordcloud==1.9.6
723
+ wrapt==2.1.1
724
+ wsproto==1.3.2
725
+ wurlitzer==3.1.1
726
+ xarray==2025.12.0
727
+ xarray-einstats==0.10.0
728
+ xformers==0.0.29.post1
729
+ xgboost==3.2.0
730
+ xlrd==2.0.2
731
+ xxhash==3.6.0
732
+ xyzservices==2025.11.0
733
+ yarl==1.22.0
734
+ ydf==0.15.0
735
+ yellowbrick==1.5
736
+ yfinance==0.2.66
737
+ zict==3.0.0
738
+ zipp==3.23.0
739
+ zstandard==0.25.0
v0.37_20260227_192656740_UTC/retraining_pipeline.py ADDED
@@ -0,0 +1,2265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from unsloth import FastLanguageModel, \
3
+ is_bfloat16_supported, UnslothTrainer, \
4
+ UnslothTrainingArguments
5
+
6
+ import torch
7
+
8
+ import os
9
+ import gc
10
+ import re
11
+ import sys
12
+ import json
13
+ import time
14
+ import shutil
15
+ import logging
16
+ import builtins
17
+
18
+ import importlib.util
19
+ from enum import Enum
20
+ from textwrap import dedent
21
+ from datetime import datetime, \
22
+ timezone
23
+
24
+ import polars as pl
25
+ from polars.exceptions import ComputeError
26
+
27
+ from jinja2 import Environment, FileSystemLoader
28
+
29
+ from huggingface_hub import list_repo_commits
30
+ from datasets import load_dataset, \
31
+ Dataset, DatasetDict
32
+ from datasets.config import HF_DATASETS_CACHE, \
33
+ HF_CACHE_HOME
34
+ from transformers import AutoTokenizer
35
+
36
+ from retrain_pipelines import __version__
37
+ from retrain_pipelines.dataset.hf_utils import \
38
+ get_lazy_df, get_column_info, \
39
+ iterable_dataset_multi_buffer_sampler, \
40
+ push_dataset_version_to_hub
41
+ from retrain_pipelines.dataset.tool_calls import \
42
+ count_tool_occurrences, plot_tools_occurences, \
43
+ column_words_stats, plot_words_count, \
44
+ get_unique_tools
45
+ from retrain_pipelines.utils.hf_utils import \
46
+ get_repo_version, get_new_repo_minor_version, \
47
+ push_files_to_hub_repo_branch
48
+
49
+ from retrain_pipelines.dag_engine.core import \
50
+ TaskPayload, task, dag, DagParam, ctx, UiCss
51
+
52
+ from retrain_pipelines.dag_engine.rp_logging import \
53
+ rp_redirect_stdout
54
+
55
+ from retrain_pipelines.dag_engine.sdk import \
56
+ ExecutionsIterator
57
+
58
+ from retrain_pipelines.utils import create_requirements
59
+
60
+
61
+ #--- helpers ----------------------------------------------------------------------------
62
+
63
+
64
# Module-level logger for this pipeline module; set to DEBUG so every
# diagnostic emitted here is captured (handlers/formatters are presumably
# attached elsewhere by the DAG engine at runtime -- TODO confirm).
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
66
+
67
+
68
class LocalServeReadinessEnum(Enum):
    """Status of the local-serve (infra-validation) step.

    Encodes a "3+"-state outcome rather than a plain bool:

    - ``NOT_APPLICABLE`` (-1): the step was skipped because the
      model version was not blessed.
    - ``FAILURE`` (0) / ``SUCCESS`` (1): boolean-like outcome of
      the infra-validation run.
    - ``FAILURE_NO_DOCKER`` (2): the run could not even start
      because no Docker runtime was available.
    """

    NOT_APPLICABLE = -1   # model version not blessed => step skipped
    FAILURE = 0           # infra-validation ran and failed
    FAILURE_NO_DOCKER = 2 # could not run: Docker unavailable
    SUCCESS = 1           # infra-validation ran and passed
80
+
81
+
82
def clear_gc():
    """Force a full garbage-collection pass and release cached CUDA memory.

    The previous implementation iterated ``gc.get_objects()`` and executed
    ``del obj`` on CUDA tensors. That only unbound the *local loop variable*
    (which the next iteration rebinds anyway) -- it never dropped the real
    references holding those tensors alive, so no GPU memory was freed.
    Actually returning device memory requires collecting unreachable
    objects and then asking the CUDA caching allocator to release its
    unused cached blocks via ``torch.cuda.empty_cache()``.

    Returns:
        None
    """
    # Reclaim unreachable Python objects first, so any CUDA tensors that
    # are garbage get their storage released back to the caching allocator.
    gc.collect()
    if torch.cuda.is_available():
        # Hand cached-but-unused device memory back to the driver so
        # other allocations/processes can use it.
        torch.cuda.empty_cache()
95
+
96
+
97
+ #--- retraining-pipeline elements -------------------------------------------------------
98
+
99
+
100
+ @task
101
+ def start() -> TaskPayload:
102
+ logger.info(f"{ctx.pipeline_name} - {ctx.exec_id}")
103
+ logging.getLogger("retrain_pipelines").setLevel(logging.INFO)
104
+
105
+ # inputs validation
106
+ repo_id_pattern = re.compile(
107
+ r"""
108
+ ^ # start
109
+ (?!.*\.\.) # no '..' anywhere
110
+ (?!.*--) # no '--' anywhere
111
+ (?: # legacy: single segment OR namespace/repo
112
+ [A-Za-z0-9._-]+ # legacy: gpt2, bert-base-uncased, etc.
113
+ |
114
+ [A-Za-z0-9._-]+/[A-Za-z0-9._-]+ # namespace/repo_name
115
+ )
116
+ $ # end
117
+ """,
118
+ re.VERBOSE
119
+ )
120
+ ctx.hf_dataset = json.loads(ctx.hf_dataset)
121
+ assert repo_id_pattern.match(ctx.hf_dataset["repo_id"]) is not None, \
122
+ f"Invalid repo_id format: {ctx.hf_dataset['repo_id']!r}"
123
+ ctx.augmentation_rate = float(ctx.augmentation_rate)
124
+ ctx.hf_enrich_dataset = json.loads(ctx.hf_enrich_dataset)
125
+ assert repo_id_pattern.match(ctx.hf_enrich_dataset["repo_id"]) is not None, \
126
+ f"Invalid repo_id format: {ctx.hf_enrich_dataset['repo_id']!r}"
127
+ ctx.enrichment_rate = float(ctx.enrichment_rate)
128
+ assert repo_id_pattern.match(ctx.dataset_repo_id) is not None, \
129
+ f"Invalid repo_id format: {dataset_repo_id!r}"
130
+ assert ctx.polars_engine in ["gpu", "cpu"]
131
+ ctx.hf_base_model = json.loads(ctx.hf_base_model)
132
+ assert repo_id_pattern.match(ctx.hf_base_model["repo_id"]) is not None, \
133
+ f"Invalid repo_id format: {ctx.hf_base_model['repo_id']!r}"
134
+ ctx.cpt_training_args = json.loads(ctx.cpt_training_args)
135
+ ctx.sft_training_args = json.loads(ctx.sft_training_args)
136
+ assert repo_id_pattern.match(ctx.model_repo_id) is not None, \
137
+ f"Invalid repo_id format: {model_repo_id!r}"
138
+
139
+ # GPU availability
140
+ logger.info(torch.cuda.get_device_name(0))
141
+ logger.info(torch.__version__)
142
+ ctx.engine = "cpu" if (
143
+ ctx.polars_engine == "gpu" and
144
+ not torch.cuda.is_available()
145
+ ) else ctx.polars_engine
146
+ logger.debug(f"Polars engine : {ctx.engine}")
147
+
148
+ # hf_dataset
149
+ hf_dataset_dict = \
150
+ get_lazy_df(
151
+ repo_id=ctx.hf_dataset["repo_id"],
152
+ commit_hash=ctx.hf_dataset["commit_hash"],
153
+ config_name=(
154
+ ctx.hf_dataset["config_name"] and
155
+ "" < ctx.hf_dataset["config_name"]
156
+ ),
157
+ hf_token=os.getenv("HF_TOKEN", None)
158
+ )
159
+ try:
160
+ logger.info(f"hf_dataset_dict lazy_df : {hf_dataset_dict['lazy_df']}")
161
+ logger.info(
162
+ f"{hf_dataset_dict['repo_id']}, " +
163
+ f"{hf_dataset_dict['commit_hash']} - " +
164
+ f"{hf_dataset_dict['commit_datetime']}\n" +
165
+ hf_dataset_dict["lazy_df"].explain()
166
+ )
167
+ except ComputeError as ex:
168
+ if "HF_TOKEN" not in os.environ:
169
+ logger.info("Does the Hugging Face-hosted dataset " +
170
+ "require authentication ?",
171
+ file=sys.stderr, flush=True)
172
+ raise ex
173
+ hf_dataset_version = get_repo_version(
174
+ repo_id=hf_dataset_dict["repo_id"],
175
+ revision=hf_dataset_dict["commit_hash"],
176
+ repo_type="dataset",
177
+ hf_token=os.getenv("HF_TOKEN", None)
178
+ )
179
+ hf_dataset_dict["version_label"] = (
180
+ f"{hf_dataset_version[0]}.{hf_dataset_version[1]}"
181
+ if sum(hf_dataset_version) > 0
182
+ else None
183
+ )
184
+ ctx.hf_dataset_dict = hf_dataset_dict
185
+
186
+ # hf_enrich_dataset
187
+ hf_enrich_dataset_dict = \
188
+ get_lazy_df(
189
+ repo_id=ctx.hf_enrich_dataset["repo_id"],
190
+ commit_hash=ctx.hf_enrich_dataset["commit_hash"],
191
+ config_name=(
192
+ ctx.hf_enrich_dataset["config_name"] and
193
+ "" < ctx.hf_enrich_dataset["config_name"]
194
+ ),
195
+ hf_token=os.getenv("HF_TOKEN", None)
196
+ )
197
+ hf_enrich_dataset_version = get_repo_version(
198
+ repo_id=hf_enrich_dataset_dict["repo_id"],
199
+ revision=hf_enrich_dataset_dict["commit_hash"],
200
+ repo_type="dataset",
201
+ hf_token=os.getenv("HF_TOKEN", None)
202
+ )
203
+ hf_enrich_dataset_dict["version_label"] = (
204
+ f"{hf_enrich_dataset_version[0]}.{hf_enrich_dataset_version[1]}"
205
+ if sum(hf_enrich_dataset_version) > 0
206
+ else None
207
+ )
208
+ logger.info(' ; '.join(f"{k}: {hf_enrich_dataset_dict[k]}"
209
+ for k in ['commit_hash',
210
+ 'commit_datetime']))
211
+ ctx.hf_enrich_dataset_dict = hf_enrich_dataset_dict
212
+
213
+ # hf_base_model
214
+ hf_base_model_revision=(
215
+ None if (rev_commit_hash:=ctx.hf_base_model["commit_hash"]) == ""
216
+ else rev_commit_hash
217
+ )
218
+ hf_base_model_commit = list_repo_commits(
219
+ repo_id=ctx.hf_base_model["repo_id"],
220
+ revision=hf_base_model_revision,
221
+ repo_type="model",
222
+ token=os.getenv("HF_TOKEN", None)
223
+ )[0]
224
+ # version major+minor=0 for non retrain-pipelines models
225
+ hf_base_model_version = get_repo_version(
226
+ repo_id=ctx.hf_base_model["repo_id"],
227
+ revision=hf_base_model_revision,
228
+ repo_type="model",
229
+ hf_token=os.getenv("HF_TOKEN", None)
230
+ )
231
+ ctx.hf_base_model_dict = {
232
+ "repo_id": ctx.hf_base_model["repo_id"],
233
+ "version_label": (
234
+ f"{hf_base_model_version[0]}.{hf_base_model_version[1]}"
235
+ if sum(hf_base_model_version) > 0
236
+ else None
237
+ ),
238
+ "commit_hash": hf_base_model_commit.commit_id,
239
+ "commit_datetime": \
240
+ hf_base_model_commit.created_at
241
+ }
242
+
243
+
244
+ ctx.model_version_blessed = False
245
+ ctx.current_blessed_exec = None
246
+ ctx.current_blessed_version_dict = None
247
+
248
+ ctx.retrain_pipelines = f"retrain-pipelines {__version__}"
249
+ ctx.retrain_pipeline_type = os.environ["retrain_pipeline_type"]
250
+
251
+
252
+ ctx.serving_artifacts_local_folder = os.path.realpath(os.path.join(
253
+ os.path.dirname(__file__), "..", "..", "serving_artifacts",
254
+ ctx.pipeline_name, str(ctx.exec_id)
255
+ ))
256
+
257
+ if not os.path.exists(ctx.serving_artifacts_local_folder):
258
+ os.makedirs(ctx.serving_artifacts_local_folder)
259
+
260
+
261
+ ctx.unsloth_dir = os.path.join(
262
+ ctx.serving_artifacts_local_folder,
263
+ "Unsloth"
264
+ )
265
+ logger.debug(f"unsloth_dir : {ctx.unsloth_dir}")
266
+ ctx.cpt_model_dir = os.path.join(ctx.unsloth_dir, "cpt_model")
267
+ ctx.sft_model_dir = os.path.join(ctx.unsloth_dir, "sft_model")
268
+
269
+ return None
270
+
271
+
272
@task
def eda(_) -> None:
    """
    Exploratory data analysis.

    Computes and stores on the flow context :
      - record count and column schema of the main dataset,
      - tool-call occurrence stats/figure for the answers column,
      - word-count stats/figure for the queries column,
      - word-count stats for the enrichment dataset's query attribute
        (printed only, then discarded).
    """
    # ------------------------------------------------------------------
    # features and label : basic counts
    # ------------------------------------------------------------------
    lazy_df = ctx.hf_dataset_dict["lazy_df"]
    ctx.records_count = (
        lazy_df.select(pl.len()).collect(engine=ctx.engine).item())
    ctx.data_schema = get_column_info(lazy_df, engine=ctx.engine)

    # ------------------------------------------------------------------
    # answers : tools count
    # ------------------------------------------------------------------
    struct_schema = pl.Struct([
        pl.Field("name", pl.String),
        # we retrieve the list of argument names
        # (without assigned values)
        pl.Field("arguments", pl.List(pl.String)),
    ])
    tool_answer_occurrences_df = count_tool_occurrences(
        lazy_df,
        ctx.hf_dataset["attributes"]["answers_attr"],
        struct_schema
    ).collect(engine=ctx.engine)
    print(f"{tool_answer_occurrences_df['occurrences'].sum():,} "
          f"query/tool-calls pairs")
    ctx.answers_tools_count_fig = plot_tools_occurences(
        tool_answer_occurrences_df,
        title_prefix="Dataset answers - ")

    # ------------------------------------------------------------------
    # query : words count
    # ------------------------------------------------------------------
    query_attr = ctx.hf_dataset["attributes"]["query_attr"]
    queries_max_length = lazy_df.select(
        pl.col(query_attr)
          .str.len_chars().max().alias("max_query_length")
    ).collect(engine=ctx.engine)
    print(f"longuest query counts "
          f"{queries_max_length['max_query_length'][0]:,} characters")

    # queries length quartiles
    ctx.query_words_stats = column_words_stats(
        lazy_df, query_attr
    ).collect(engine=ctx.engine)
    print(ctx.query_words_stats.to_pandas().to_string(index=False))
    print("Two thirds of the records have a query with less than "
          f"{ctx.query_words_stats['q3'][0]} words.")

    ctx.words_count_fig = plot_words_count(
        lazy_df,
        column_name=query_attr,
        engine=ctx.engine)

    # ------------------------------------------------------------------
    # hf_enrich_dataset : query words count
    # ------------------------------------------------------------------
    # NOTE(review): `eval` on a config-provided string — acceptable
    # only if pipeline configuration is trusted input ; confirm.
    enrich_question_words_stats = column_words_stats(
        ctx.hf_enrich_dataset_dict['lazy_df'],
        ctx.hf_enrich_dataset["query_attribute"],
        column_attr_handler=eval(
            ctx.hf_enrich_dataset["query_attribute_handler"])
    ).collect(engine=ctx.engine)
    print(enrich_question_words_stats.to_pandas()
          .to_string(index=False))
    del enrich_question_words_stats

    return None
362
@task
def augment_data(_) -> None:
    """
    Add 'negative' examples, where queries do not trigger any
    tool call. To achieve that, we sample long user queries,
    truncate at half words count, and associate this to
    an empty list of tool-calls.

    We only consider :
      - records with longest queries, i.e. queries in the last
        quartile of "queries with most word-counts"
        (this is to avoid that 'truncated' queries get really short)
      - records with answers consisting in a single tool-call
        (in order to minimize the risk that truncating actually
        gives a valid answer with one tool-call [or more])

    Note on flow 'augmentation_rate' :
      we add that many records (at most), as quartile size permits.
    """
    print("Sampling within the population with more than " +
          str(ctx.query_words_stats['q3'][0]) +
          " words (longest queries quartile) =>")

    samples_count = int(ctx.records_count * ctx.augmentation_rate)
    print(f"{ctx.augmentation_rate:.1%} would represent " +
          f"{samples_count:,.0f} records to be sampled")

    query_attr = ctx.hf_dataset["attributes"]["query_attr"]
    eligible_records_df = ctx.hf_dataset_dict["lazy_df"].filter(
        # query in the longest-queries quartile.
        # Native `.list.len()` replaces the original slow
        # Python-level `map_elements(lambda arr: len(arr))` UDF
        # (which was also capped at Int16, overflow-prone).
        pl.col(query_attr)
          .str.extract_all(r"\w+")
          .list.len()
          .gt(ctx.query_words_stats['q3'][0])
        # answer made of exactly one tool-call
        & pl.col("answers")
          .map_elements(
              lambda x: len(json.loads(x)) == 1
                        if isinstance(x, str)
                        else False,
              return_dtype=pl.Boolean)
    ).collect(engine=ctx.engine)
    # `.height` avoids materializing an extra count frame
    eligible_records_count = eligible_records_df.height
    print(f"eligible_records_count : " +
          f"{eligible_records_count:,.0f}")
    samples_count = min(samples_count, eligible_records_count)
    ctx.actual_augmentation_rate = \
        samples_count / ctx.records_count
    print("actual augmentation rate : " +
          f"{ctx.actual_augmentation_rate:.1%}")
    sampled_records_df = eligible_records_df.sample(
        n=samples_count
    )

    def _truncate_half(query: str) -> str:
        # split once (the original split the query twice)
        words = query.split()
        return " ".join(words[:len(words) // 2])

    ctx.augmented_records_df = \
        sampled_records_df.with_columns(
            pl.col("query")
              .map_elements(_truncate_half, return_dtype=pl.Utf8)
              .alias("truncated_query")
        ).select([
            pl.col("truncated_query").alias("query"),
            pl.lit("[]").alias("answers")
        ])
    print(ctx.augmented_records_df.height,
          ctx.augmented_records_df.columns)

    return None
450
@task
def enrich_data(_) -> None:
    """
    Further enrich our dataset with 'negative' records from
    another dataset (can be general-purpose text dataset)
    as specified by the flow 'hf_enrich_dataset' argument.

    Note : we here use the Hugging Face `datasets` library
    in 'streaming' mode for records sampling.
    """
    hf_enrich_ds = load_dataset(
        path=ctx.hf_enrich_dataset["repo_id"],
        name=ctx.hf_enrich_dataset["config_name"],
        revision=ctx.hf_enrich_dataset_dict["commit_hash"],
        streaming=True)
    print(hf_enrich_ds["train"])

    samples_count = \
        int(ctx.records_count * ctx.enrichment_rate)
    print(f"Samplig {samples_count:,.0f} records")

    # NOTE(review): `eval` on a config-provided string — acceptable
    # only if pipeline configuration is trusted input ; confirm.
    query_attribute_handler = \
        eval(ctx.hf_enrich_dataset["query_attribute_handler"])
    samples_iterator = iterable_dataset_multi_buffer_sampler(
        hf_enrich_ds["train"],
        total_samples=samples_count,
        attributes_selector=\
            (lambda x: query_attribute_handler(
                x[ctx.hf_enrich_dataset["query_attribute"]])),
        buffer_size=3_000,
        num_passes=3,
        seed=None
    )
    # Capitalize and add end punctuation if missing
    start_time = time.time()
    print("Starting sample enriching records, " +
          "this may take some time if the source dataset " +
          "has a complex structure..")
    # bug fix : guard against empty sampled strings —
    # `s[-1]` raised IndexError on "" in the original.
    samples_list = [
        s.capitalize() + ("" if s and s[-1] in ".!?" else "?")
        for s in samples_iterator]
    elapsed_time = time.time() - start_time
    print(f".. sampling completed " +
          f"({int(elapsed_time // 3_600)}h:" +
          f"{int((elapsed_time % 3_600) // 60)}m:" +
          f"{int(elapsed_time % 60)}s).")
    ctx.enriched_records_df = pl.DataFrame(
        {"query": samples_list,
         # one empty tool-calls list per sampled query
         "answers": ["[]"] * len(samples_list)}
    )

    return None
509
@task(ui_css=UiCss(background="#FF9900", color="#111827", border="#1F2937"))
def dataset_to_hub(_) -> None:
    """
    Push to hub dataset version
      - continued pre-training dataset
      - training and validation splits of the
        augmented and enriched supervised finetuning dataset
      - readme with versioning info
    """
    #############################
    # case of user-provided     #
    # documentation artifact(s) #
    #############################
    # note that user can provide either 'pipeline_card.py'
    # or 'template.html' or 'dataset_readme.py'
    # or 'dataset_readme_template.md' or 'model_readme.py'
    # or 'model_readme_template.md' or any combination of those
    # when specifying custom 'pipeline_card_artifacts_path'
    if (
        "dataset_readme_template.md" in
        os.listdir(ctx.pipeline_card_artifacts_path)
    ):
        template_dir = ctx.pipeline_card_artifacts_path
    else:
        template_dir = os.path.dirname(
            importlib.util.find_spec(
                f"retrain_pipelines.pipeline_card."+
                f"{os.getenv('retrain_pipeline_type')}"
            ).origin)
    print(f"template_dir : '{template_dir}'")

    if "dataset_readme.py" in os.listdir(
            ctx.pipeline_card_artifacts_path):
        from retrain_pipelines.utils import \
            get_get_dataset_readme_content
        get_dataset_readme_content = \
            get_get_dataset_readme_content(
                ctx.pipeline_card_artifacts_path)
    else:
        from retrain_pipelines.pipeline_card import \
            get_dataset_readme_content

    #############################
    # augmented & enriched      #
    # finetuning dataset        #
    #############################
    merged_df = pl.concat([
        # dataset
        ctx.hf_dataset_dict["lazy_df"].select([
            ctx.hf_dataset["attributes"]["query_attr"],
            ctx.hf_dataset["attributes"]["answers_attr"]
        ]).collect(engine=ctx.engine),
        # truncated queries augmentation
        ctx.augmented_records_df,
        # enriching dataset
        ctx.enriched_records_df
    ]).sample(
        # shuffling
        fraction=1,
        shuffle=True,
        with_replacement=False
    )
    # bug fix : assign `rechunk()`'s return value (the original
    # discarded it) ; also dropped a redundant second
    # full-frame shuffle that immediately followed the first.
    merged_df = merged_df.rechunk()
    print(("merged_df", f"{merged_df.shape[0]:,.0F}",
           merged_df.columns))

    # 80/20 train/validation split (data already shuffled above)
    pandas_df = merged_df.to_pandas()
    train_size = int(0.8 * len(pandas_df))
    print(f"validation : {len(pandas_df) - train_size}")
    sft_dataset = DatasetDict({
        "train": Dataset.from_pandas(pandas_df[:train_size]),
        "validation": Dataset.from_pandas(pandas_df[train_size:])
    })

    #############################
    # continued pre-training    #
    # dataset                   #
    #############################
    struct_schema = pl.Struct([
        pl.Field("name", pl.String),
        pl.Field("description", pl.String),
        # Use String to allow for varying structures
        # (different tools indeed having different sets of
        # parameters, i.e. different parameters counts,
        # datatypes and names) so parsing must be tolerant.
        pl.Field("parameters", pl.String)
    ])
    unique_tools_df = get_unique_tools(
        ctx.hf_dataset_dict["lazy_df"],
        tools_attr_name=\
            ctx.hf_dataset["attributes"]["tools_attr"],
        struct_schema=struct_schema
    ).collect(engine=ctx.engine)
    ctx.unique_tools_dataset = Dataset(unique_tools_df.to_arrow())
    print(ctx.unique_tools_dataset)

    #############################
    # DatasetDict               #
    # with multiple tables      #
    #############################
    dataset_dict = DatasetDict({
        "continued_pre_training": ctx.unique_tools_dataset,
        "supervised_finetuning": sft_dataset
    })
    print(dataset_dict, flush=True)

    #############################
    # dataset README            #
    # from template             #
    #############################
    # NOTE(review): `datetime.utcnow()` is deprecated (py3.12+) ;
    # `datetime.now(timezone.utc)` would yield a tz-aware value —
    # confirm downstream consumers before switching.
    commit_datetime = datetime.utcnow()
    new_dataset_version_label = get_new_repo_minor_version(
        repo_id=ctx.dataset_repo_id,
        repo_type="dataset",
        hf_token=os.getenv("HF_TOKEN", None))
    readme_content = get_dataset_readme_content(
        template_folder=template_dir,

        hf_dataset_dict=ctx.hf_dataset_dict,
        hf_enrich_dataset_dict=ctx.hf_enrich_dataset_dict,
        dataset_dict=dataset_dict,

        augmentation_rate=ctx.actual_augmentation_rate,
        enrichment_rate=ctx.enrichment_rate,

        version_label=new_dataset_version_label,
        commit_datetime=commit_datetime,

        pipeline_name=ctx.pipeline_name,
        exec_id=ctx.exec_id,
        engine=ctx.engine
    )

    dataset_commit_hash = push_dataset_version_to_hub(
        repo_id=ctx.dataset_repo_id,
        version_label=new_dataset_version_label,
        timestamp_str=commit_datetime.strftime(
            "%Y-%m-%d %H:%M:%S UTC"),
        dataset_dict=dataset_dict,
        dataset_readme_content=readme_content,
        hf_token=os.getenv("HF_TOKEN", None)
    )
    if not dataset_commit_hash:
        raise Exception(
            "Failed to publish dataset version.")
    print(f"https://huggingface.co/datasets/{ctx.dataset_repo_id}" +
          f"/blob/{dataset_commit_hash}/README.md")
    ctx.dataset_commit_dict = {
        "repo_id": ctx.dataset_repo_id,
        "commit_hash": dataset_commit_hash,
        "version_label": new_dataset_version_label,
        "commit_datetime": commit_datetime,
    }

    return None
687
@task
def continued_pre_training(_) -> None:
    """
    Gives the base model some additional intrinsic knowkledge
    through continued pre-training.
    See unsloth.ai/blog/contpretraining
    """
    from retrain_pipelines.model.hf_utils import \
        plot_log_history

    # ------------------------------------------------------------------
    # base-model and associated tokenizer from Hub (or local cache)
    # ------------------------------------------------------------------
    ctx.max_seq_length = 2048
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=ctx.hf_base_model_dict["repo_id"],
        revision=ctx.hf_base_model_dict["commit_hash"],
        max_seq_length=ctx.max_seq_length,
        dtype=None,
        load_in_4bit=False,
        # case of a gated or private base-model
        token=os.getenv("HF_TOKEN", None)
    )

    # ------------------------------------------------------------------
    # dataset prompt_template mapping
    # ------------------------------------------------------------------
    tools_dataset = DatasetDict(
        {"train": ctx.unique_tools_dataset})
    print(tools_dataset)
    tool_prompt_template = "tool: {}"

    def formatting_prompts_func(tools_batch):
        formatted = []
        for tool in tools_batch["tool"]:
            # Must add EOS_TOKEN,
            # otherwise generation will go on forever!
            formatted.append(
                tool_prompt_template.format(tool)
                + tokenizer.eos_token)
        return {"tools": formatted}

    import dis as _dis_module
    # silence dis.py bytecode spam (Unsloth monkey-patch side-effect)
    _dis_module.print = lambda *a, **kw: None
    cpt_dataset = tools_dataset["train"].map(
        formatting_prompts_func, batched=True)
    del _dis_module.print

    # ------------------------------------------------------------------
    # PEFT adapter for continued pre-training
    # ------------------------------------------------------------------
    model = FastLanguageModel.get_peft_model(
        model,
        r=128,  # any number >0 ; 8, 16, 32, 64, 128, 256
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj",
                        # Add for continued pretraining
                        "embed_tokens", "lm_head",],
        lora_alpha=32,
        lora_dropout=0,  # Supports any, 0 is optimized
        bias="none",  # Supports any, "none" is optimized
        # True or "unsloth" for very long context
        use_gradient_checkpointing="unsloth",
        use_rslora=True,  # rank-stabilized LoRA
        loftq_config=None,  # LoftQ
        #random_state = 3407,
    )

    # ------------------------------------------------------------------
    # cpt_trainer
    # ------------------------------------------------------------------
    if (
        "records_cap" in ctx.cpt_training_args and
        ctx.cpt_training_args["records_cap"] is not None and
        isinstance(ctx.cpt_training_args["records_cap"], int)
    ):
        cpt_dataset = cpt_dataset.take(
            ctx.cpt_training_args["records_cap"])
    print(f"cpt_dataset : {cpt_dataset}")

    train_args = UnslothTrainingArguments(
        # https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.save_strategy
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,

        # user-provided args override the surrounding defaults
        **{k: v for k, v in ctx.cpt_training_args.items()
           if k != "records_cap"},

        # 2 to 10x smaller learning rate
        # for the embedding matrices
        learning_rate=5e-5,
        embedding_learning_rate=1e-5,

        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        #seed=3407,

        output_dir=os.path.join(
            ctx.unsloth_dir, "outputs", "cpt"),
        save_total_limit=2,

        report_to="tensorboard",
        # NOTE(review): CPT tensorboard logs land under
        # `sft_model_dir` — looks like co-located runs on purpose,
        # but confirm this isn't meant to be `cpt_model_dir`.
        logging_dir=os.path.join(
            ctx.sft_model_dir,
            "runs", "cpt")
    )

    # silence dis.py bytecode spam (Unsloth monkey-patch side-effect)
    _dis_module.print = lambda *a, **kw: None
    trainer = UnslothTrainer(
        model=model, tokenizer=tokenizer,
        train_dataset=cpt_dataset,
        dataset_text_field="tools",
        max_seq_length=ctx.max_seq_length,
        dataset_num_proc=2,
        args=train_args,
    )
    del _dis_module.print

    # ------------------------------------------------------------------
    # Show current memory stats
    # ------------------------------------------------------------------
    torch.cuda.ipc_collect()
    torch.cuda.empty_cache()
    _ = gc.collect()

    gpu_stats = torch.cuda.get_device_properties(0)
    ctx.start_gpu_memory = round(
        torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
    ctx.max_memory = round(
        gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
    print(f"GPU = {gpu_stats.name}. " +
          f"Max memory = {ctx.max_memory} GB.")
    print(f"{ctx.start_gpu_memory} GB of memory reserved.")

    # live traces go to a dedicated file (stdout redirected)
    ctx.cpt_traces_file_fullname = os.path.join(
        ctx.unsloth_dir, "cpt_trainer_traces.txt")
    logger.info(
        "Training started. " +
        f"Check [underline]{ctx.cpt_traces_file_fullname}[/] for live traces " +
        "or go watch your [white bold]TensorBoard[/] charts live updates !"
    )
    with open(ctx.cpt_traces_file_fullname, 'w') as f:
        with rp_redirect_stdout(f):
            trainer_stats = trainer.train()
    print(f"{trainer_stats.metrics['train_runtime']} " +
          f"seconds used for CPT training " +
          f"({round(trainer_stats.metrics['train_runtime']/60, 2)}" +
          f" minutes).")

    ctx.cpt_log_history = trainer.state.log_history
    ctx.cpt_log_history_fig = plot_log_history(
        ctx.cpt_log_history,
        title="Continued pretraining loss"
    )
    del trainer

    model.save_pretrained_merged(
        save_directory=ctx.cpt_model_dir,
        tokenizer=tokenizer,
        save_method="lora"
    )
    print(f"cpt_model_dir : {ctx.cpt_model_dir}\n")

    # vRAM & RAM cleanup
    # (incl. force-delete all CUDA tensors in gc)
    del model
    del tokenizer
    clear_gc()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    print(f"After cleanup: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")

    return None
879
@task
def supervised_finetuning(_) -> None:
    """
    Trains the model on tool-calling
    task specialization.
    """
    from retrain_pipelines.model.hf_utils import \
        plot_log_history

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=ctx.cpt_model_dir,
        max_seq_length=ctx.max_seq_length,
        dtype=None,
        load_in_4bit=False,
    )
    # !!!! bug fix BEGIN !!!!
    # otherwise, 'embed_tokens' and 'lm_head' trained during CPT
    # are "ignored", i.e. not saved after SFT
    # (note that, alternatively, we could also do this fix after
    # sft-training and just before saving ; which would be
    # equivalent to freezing embeddings during finetuning
    # for better pretrained knowledge retention)
    # @see https://www.reddit.com/r/unsloth/comments/1dtzcd6/fastlanguagemodelpatch_peft_model_changing/
    for _module in (model.model.model.embed_tokens,
                    model.model.lm_head):
        _module.modules_to_save.default.to(
            device="cuda:0",
            dtype=torch.float32,
            non_blocking=True)
        _module.modules_to_save.default.requires_grad_(True)
    # !!!! bug fix END !!!!

    # ------------------------------------------------------------------
    # dataset prompt_template mapping
    # ------------------------------------------------------------------
    # download from Hub (or get from local cache)
    queries_dataset = load_dataset(
        path=ctx.dataset_commit_dict["repo_id"],
        name="supervised_finetuning",
        revision=ctx.dataset_commit_dict["commit_hash"],
        token=os.getenv("HF_TOKEN", None))
    print(f"HF_DATASETS_CACHE : {HF_DATASETS_CACHE}")  # HF_CACHE_HOME
    ctx.sft_prompt_template = dedent("""
        You specialize in generating tool calls. Given a query, your task is to return a list of tool calls based on your knowledge of known tools.

        Rules:
        1. You can only use tools you know. Do not create new tools under any circumstances.
        2. If a query does not match any known tool, return an empty list ([]).
        3. If information is missing to use a known tool, do not attempt to use it.
        4. Your response must always be a valid JSON array, and nothing else.

        Be precise and do not guess.

        # query:
        {}
        # response:
        {}
    """).strip()
    tokenizer.chat_template = ctx.sft_prompt_template

    EOS_TOKEN = tokenizer.eos_token

    def formatting_prompts_func(records):
        formatted = []
        for query, tools in zip(records["query"],
                                records["answers"]):
            # Must add EOS_TOKEN,
            # otherwise your generation will go on forever
            formatted.append(
                ctx.sft_prompt_template.format(query, tools)
                + EOS_TOKEN)
        return {"text": formatted}

    import dis as _dis_module
    # silence dis.py bytecode spam (Unsloth monkey-patch side-effect)
    _dis_module.print = lambda *a, **kw: None
    sft_train_dataset = queries_dataset["train"].map(
        formatting_prompts_func, batched=True)
    # NOTE(review): mapped but not referenced below — presumably
    # kept for its HF-cache side effect ; confirm.
    sft_valid_dataset = queries_dataset["validation"].map(
        formatting_prompts_func, batched=True)
    del _dis_module.print

    # ------------------------------------------------------------------
    # PEFT adapter for supervised finetuning
    # ------------------------------------------------------------------
    # intentionally NOT creating a fresh adapter here : that is only
    # needed when CPT has been merged into the overall model ;
    # otherwise we keep on training the current LoRa adapter.

    # ------------------------------------------------------------------
    # sft_trainer
    # ------------------------------------------------------------------
    split = sft_train_dataset.train_test_split(
        test_size=1000,
        #seed=42
    )
    train_dataset = split['train']
    eval_dataset = split['test']
    if (
        "records_cap" in ctx.sft_training_args and
        ctx.sft_training_args["records_cap"] is not None and
        isinstance(ctx.sft_training_args["records_cap"], int)
    ):
        train_dataset = train_dataset.take(
            ctx.sft_training_args["records_cap"])
        eval_dataset = eval_dataset.take(
            ctx.sft_training_args["records_cap"])
    print(f"train_dataset : {train_dataset}")
    print(f"eval_dataset : {eval_dataset}")

    train_args = UnslothTrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,

        # user-provided args override the surrounding defaults
        **{k: v for k, v in ctx.sft_training_args.items()
           if k != "records_cap"},

        per_device_eval_batch_size=2,
        eval_steps=200,
        eval_strategy="steps",
        do_eval=True,

        learning_rate=5e-5,
        # embedding_learning_rate=1e-5, # Optionally here

        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),

        optim="adamw_8bit",
        weight_decay=0.00,
        lr_scheduler_type="linear",
        #seed=3407,

        output_dir=os.path.join(
            ctx.unsloth_dir, "outputs", "sft"),
        save_total_limit=2,

        disable_tqdm=True,
        logging_steps=1,
        report_to="tensorboard",
        logging_dir=os.path.join(
            ctx.sft_model_dir,
            "runs", "sft")
    )

    # silence dis.py bytecode spam (Unsloth monkey-patch side-effect)
    _dis_module.print = lambda *a, **kw: None
    trainer = UnslothTrainer(
        model=model, tokenizer=tokenizer,
        train_dataset=train_dataset,
        dataset_text_field="text",
        eval_dataset=eval_dataset,
        max_seq_length=ctx.max_seq_length,
        dataset_num_proc=8,
        args=train_args
    )
    del _dis_module.print
    trainer.can_return_loss = True

    # ------------------------------------------------------------------
    # Show current memory stats
    # ------------------------------------------------------------------
    torch.cuda.ipc_collect()
    torch.cuda.empty_cache()
    _ = gc.collect()

    used_memory = round(
        torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
    used_memory_for_lora = round(
        used_memory - ctx.start_gpu_memory, 3)
    used_percentage = round(
        used_memory / ctx.max_memory * 100, 3)
    lora_percentage = round(
        used_memory_for_lora / ctx.max_memory * 100, 3)
    print(f"Peak reserved memory = " +
          f"{used_memory} GB.")
    print(f"Peak reserved memory for " +
          f"training = {used_memory_for_lora} " +
          f"GB.")
    print(f"Peak reserved memory % of " +
          f"max memory = {used_percentage} %.")
    print(f"Peak reserved memory for SFT training " +
          f"% of max memory = {lora_percentage} %.")

    # live traces go to a dedicated file (stdout redirected)
    ctx.sft_traces_file_fullname = os.path.join(
        ctx.unsloth_dir, "sft_trainer_traces.txt")
    logger.info(
        "Training started. " +
        f"Check [underline]{ctx.sft_traces_file_fullname}[/] for live traces " +
        "or go watch your [white bold]TensorBoard[/] charts live updates !"
    )
    with open(ctx.sft_traces_file_fullname, 'w') as f:
        with rp_redirect_stdout(f):
            trainer_stats = trainer.train()
    print(f"{trainer_stats.metrics['train_runtime']} " +
          f"seconds used for training " +
          f"({round(trainer_stats.metrics['train_runtime']/60, 2)}" +
          f" minutes).")

    ctx.sft_log_history = trainer.state.log_history
    ctx.sft_log_history_fig = plot_log_history(
        ctx.sft_log_history,
        title="Supervised finetuning loss"
    )
    del trainer

    model.save_pretrained_merged(
        ctx.sft_model_dir, tokenizer,
        save_method="lora"
    )
    print(f"sft_model_dir : {ctx.sft_model_dir}\n")

    # vRAM & RAM cleanup
    # (incl. force-delete all CUDA tensors in gc)
    del model
    del tokenizer
    clear_gc()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    print(f"After cleanup: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")

    return None
1132
@task
def evaluate_model(_) -> None:
    """
    Batch inference on the SFT validation dataset.

    Reloads the base model from the HF cache, plugs the
    locally-saved CPT+SFT adapter onto it, runs batched
    inference on the (possibly capped) validation split,
    then computes eval metrics and figures and stores them
    on `ctx` for downstream tasks (blessing, pipeline-card).
    Frees GPU memory before returning.
    """
    from retrain_pipelines.model import \
        infer_validation, compute_counts_n_metrics, \
        plot_validation_completions

    ######################################################
    # loading trained adapter                            #
    ######################################################
    # Unsloth [and hf transformers before it]            #
    # (if loading both model & tokenizer at once,        #
    # same as we did in prior tasks, but now with        #
    # tokenizer.chat_template being set in               #
    # tokenizer.config) forces on us some kind of        #
    # chat_template format hard-requirements.            #
    ######################################################
    # load base from cache
    # (with base tokenizer, which we ignore)
    model, _ = FastLanguageModel.from_pretrained(
        model_name=ctx.hf_base_model_dict["repo_id"],
        revision=ctx.hf_base_model_dict["commit_hash"],
        max_seq_length=ctx.max_seq_length,
        dtype=None,
        load_in_4bit=False,
        # case of a gated or private base-model
        token=os.getenv("HF_TOKEN", None)
    )
    model = FastLanguageModel.for_inference(model)
    # load our CPT+SFT trained & locally-saved adapter
    model.load_adapter(peft_model_id=ctx.sft_model_dir)
    # Separately load our (potentially trained &)
    # locally-saved adapter-tokenizer
    # (loading it below via HF and not Unsloth)
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path=ctx.sft_model_dir
    )
    ######################################################

    ######################################################
    # validation dataset                                 #
    ######################################################
    # download from Hub (or get from local cache)
    queries_dataset = load_dataset(
        path=ctx.dataset_commit_dict["repo_id"],
        name="supervised_finetuning",
        revision=ctx.dataset_commit_dict["commit_hash"],
        token=os.getenv("HF_TOKEN", None))
    # optional cap on validation records
    # (same "records_cap" knob as used for training)
    if (
        "records_cap" in ctx.sft_training_args and
        ctx.sft_training_args["records_cap"] is not None and
        isinstance(ctx.sft_training_args["records_cap"], int)
    ):
        validation_data = queries_dataset["validation"].take(
            ctx.sft_training_args["records_cap"])
    else:
        validation_data = queries_dataset["validation"]
    print(validation_data, flush=True)
    ######################################################

    # also reused later by the serving config (infra_validator)
    ctx.max_new_tokens = 400
    start_time = time.time()
    validation_results = infer_validation(
        tokenizer=tokenizer,
        model=model,
        validation_data=validation_data,
        prompt_template=tokenizer.chat_template,
        batch_size=32, # 64,
        queries_attr_name=\
            ctx.hf_dataset["attributes"]["query_attr"],
        answers_attr_name=\
            ctx.hf_dataset["attributes"]["answers_attr"],
        max_new_tokens=ctx.max_new_tokens,
        device="cuda"
    )
    print("infer_validation - Elapsed time: " +
          f"{(time.time() - start_time):.2f} seconds")
    ctx.validation_results = validation_results # <= to artifacts store

    eval_df = pl.LazyFrame(validation_results)

    # strict, character-level exact-match accuracy
    # (completion must be byte-identical to the ground-truth answer)
    records = eval_df.with_columns(
        (pl.col("answer") == pl.col("completion")) \
        .alias("is_ground_truth_identical")
    ).collect() #engine=ctx.engine)
    print("perfect characters-match accuracy : " +
          str(records['is_ground_truth_identical'].mean()))

    # per-record counts & metrics, then macro-averaged overall
    eval_metrics_df = compute_counts_n_metrics(
        eval_df, is_format_fault_tolerant=True)
    overall_metrics_df = eval_metrics_df.select([
        pl.col("precision").mean(),
        pl.col("recall").mean(),
        pl.col("f1").mean(),
        pl.col("jaccard").mean()
    ]).collect() #engine=ctx.engine)
    ctx.perf_metrics = overall_metrics_df.row(0, named=True)
    print(ctx.perf_metrics)

    ctx.validation_completions_fig = \
        plot_validation_completions(
            eval_metrics_df, engine=ctx.engine)

    # vRAM & RAM cleanup
    # (incl. force-delete all CUDA tensors in gc)
    del model
    del tokenizer
    clear_gc()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    print(f"After cleanup: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")

    return None
1247
+
1248
+
1249
@task
def model_version_blessing(_) -> None:
    """
    Compare the newly-retrained model version
    against its best-performing predecessor.

    Sets `ctx.model_version_blessed` and, when the new version
    is NOT blessed, records the prior blessed version
    (`ctx.current_blessed_version_dict`) and, when found locally,
    the DAG-engine execution that produced it
    (`ctx.current_blessed_exec`).
    """
    """
    Note: for Hugging Face integrated pipelines,
    we compare against the latest commit of the main branch
    of the model repository there.
    When it comes to the local "mf_run_id" of the pipeline run
    having generated that best prior model version
    (retrieved from model card metadata from the HF yaml section),
    we check against records of the herein ML-framework instance,
    as the "prior best version" of the model here being retrained
    may have originated from another one
    than the one executing the current retraining
    (in which case, we simply don't include a "local" hyperlink
    in the model version pipeline_cards that will be
    produced later in the herein pipeline run).
    """
    from retrain_pipelines.model.hf_utils import \
        current_blessed_model_version_dict

    # the metric that decides blessing
    main_perf_metric_name = "jaccard"

    current_blessed_version_dict = \
        current_blessed_model_version_dict(
            repo_id=ctx.model_repo_id,
            hf_token=os.getenv("HF_TOKEN", None)
        )
    print("current_blessed_version_dict : " +
          str(current_blessed_version_dict))

    if current_blessed_version_dict is None:
        # first ever (or no surviving) blessed version => auto-bless
        print("case 'no prior blessed model version found"
              " => blessing.'")
        ctx.model_version_blessed = True

    elif (
        main_perf_metric_name in
        current_blessed_version_dict["perf_metrics"]
    ):
        current_blessed_exec_id = \
            current_blessed_version_dict["exec_id"]
        print(f"current_blessed_exec_id : {current_blessed_exec_id}")
        current_blessed_metric_value = \
            current_blessed_version_dict[
                "perf_metrics"][main_perf_metric_name]

        # bless iff at least as good as the prior best
        ctx.model_version_blessed = (
            ctx.perf_metrics[main_perf_metric_name] >=
            current_blessed_metric_value
        )

        # ctx.model_version_blessed = False ### DEBUG - DELETE ###

        if not ctx.model_version_blessed:
            ctx.current_blessed_version_dict = \
                current_blessed_version_dict
            # may have failed after the "pipeline_card" task,
            # so we do not filter on success
            for execution in ExecutionsIterator(
                exec_name=ctx.pipeline_name,
                page_size=10
            ):
                if str(execution.id) == current_blessed_exec_id:
                    # Has the execution seen task "pipeline_card" which
                    # completed successfully
                    # ("execution" has generated a custom pipeline-card artifact) ?
                    # If not, hyperlink generation will later fail.
                    run_has_custom_card_artifact = (len([
                        t for t in execution.get_tasks_with_name(
                                task_type_name="pipeline_card")
                        if t.end_timestamp and t.success
                    ]) == 1)
                    if not run_has_custom_card_artifact:
                        print(
                            f"Execution #{current_blessed_exec_id} " +
                            "Doesn't seem to have successfully " +
                            "generated a pipeline-card artifact.",
                            file=sys.stderr, flush=True)

                    else:
                        # further filtering on successful executions that are
                        # retraining of a prior version of the same model
                        # (to minimize the risk that this was obtained
                        # on another DAG-engine instance).
                        # FIX: the "or" fallback is parenthesized so the
                        # repo-id comparison is actually evaluated
                        # (previously parsed as
                        #  `(A and B) or ("" == ctx.model_repo_id)`,
                        #  which made the repo-id filter a no-op).
                        if (
                            execution.get_attr("model_version_blessed") and
                            (execution.get_attr("model_repo_id") or "") == \
                                ctx.model_repo_id
                        ):
                            ctx.current_blessed_exec = execution

                    break

            if not ctx.current_blessed_exec:
                logger.warning(
                    "Couldn't find blessed execution " +
                    f"{current_blessed_exec_id} !\n" +
                    "It seems that prior blessed execution was " +
                    "executed on another DAG-engine instance.")
            else:
                logger.debug(
                    f"ctx.current_blessed_exec : {ctx.current_blessed_exec}")

        print("new : " +
              str(ctx.perf_metrics[main_perf_metric_name]) +
              " - previous best : " +
              str(current_blessed_metric_value) +
              " - model_version_blessing : " +
              str(ctx.model_version_blessed))

    else:
        raise Exception(
            "Performance metric '" +
            main_perf_metric_name +
            "' can't be found in eval results " +
            "from blessed execution " +
            str(current_blessed_version_dict[
                    "exec_id"]) + " !")

    return None
1373
+
1374
+
1375
@task(ui_css=UiCss(background="#FF9900", color="#111827", border="#1F2937"))
def model_to_hub(_) -> None:
    """
    Push to hub model version, including
    readme with versioning info.

    Renders the model README from a (possibly user-provided)
    template, pushes the locally-saved adapter + README
    to the HF model repo, and records the resulting commit
    in `ctx.model_commit_dict` for downstream tasks.
    """

    #############################
    # case of user-provided     #
    # documentation artifact(s) #
    #############################
    # note that user can provide either
    # 'pipeline_card.py' or 'template.html'
    # or 'dataset_readme.py'
    # or 'dataset_readme_template.md'
    # or 'model_readme.py'
    # or 'model_readme_template.md'
    # or any combination of those
    # when specifying custom
    # 'pipeline_card_artifacts_path'
    if (
        "model_readme_template.md" in
        os.listdir(ctx.pipeline_card_artifacts_path)
    ):
        template_dir = ctx.pipeline_card_artifacts_path
    else:
        # fall back to the packaged default template
        template_dir = os.path.dirname(
            importlib.util.find_spec(
                f"retrain_pipelines.pipeline_card."+
                f"{os.getenv('retrain_pipeline_type')}"
            ).origin)
    print(f"template_dir : '{template_dir}'")
    #############################
    if "model_readme.py" in os.listdir(
        ctx.pipeline_card_artifacts_path):
        from retrain_pipelines.utils import \
            get_get_model_readme_content
        get_model_readme_content = \
            get_get_model_readme_content(
                ctx.pipeline_card_artifacts_path)
    else:
        from retrain_pipelines.pipeline_card import \
            get_model_readme_content
    #############################
    from retrain_pipelines.model.hf_utils import \
        push_model_version_to_hub

    #############################
    # model README              #
    # from template             #
    #############################
    # FIX: `datetime.utcnow()` is deprecated (Python 3.12+) and
    # returns a naive datetime; use an aware UTC timestamp,
    # consistent with `datetime.now(tz=timezone.utc)` used
    # elsewhere in this pipeline. Rendered strings are unchanged
    # (the strftime below hard-codes "UTC").
    commit_datetime = datetime.now(tz=timezone.utc)
    new_model_version_label = get_new_repo_minor_version(
        repo_id=ctx.model_repo_id,
        repo_type="model",
        hf_token=os.getenv("HF_TOKEN", None))
    readme_content = get_model_readme_content(
        template_folder=template_dir,

        model_repo_id=ctx.model_repo_id,

        base_model_dict=ctx.hf_base_model_dict,
        training_dataset_dict=ctx.dataset_commit_dict,

        version_label=new_model_version_label,
        commit_datetime=commit_datetime,
        perf_metrics=ctx.perf_metrics,

        pipeline_name=ctx.pipeline_name,
        exec_id=ctx.exec_id
    )
    #############################

    print("Pushing model version to HF hub " +
          ("(blessed). " if ctx.model_version_blessed
           else "(not blessed). ") +
          "May take a while..",
          flush=True)
    model_commit_hash = push_model_version_to_hub(
        repo_id=ctx.model_repo_id,
        model_version_blessed=\
            ctx.model_version_blessed,
        version_label=new_model_version_label,
        timestamp_str=commit_datetime.strftime(
            "%Y-%m-%d %H:%M:%S UTC"),
        model_dir=ctx.sft_model_dir,
        model_readme_content=readme_content,
        hf_token=os.getenv("HF_TOKEN", None)
    )
    if not model_commit_hash:
        raise Exception(
            "Failed to publish model version.")
    print("Push of model version to HF hub completed.",
          flush=True)
    print(f"https://huggingface.co/{ctx.model_repo_id}" +
          f"/blob/{model_commit_hash}/README.md")

    # recorded for `pipeline_card` / `pipeline_to_hub`
    ctx.model_commit_dict = {
        "repo_id": ctx.model_repo_id,
        "commit_hash": model_commit_hash,
        "version_label": new_model_version_label,
        "commit_datetime": commit_datetime,
    }

    return None
1480
+
1481
+
1482
@task
def infra_validator(_) -> None:
    """
    If the trained model version is blessed,
    validate serving.

    Builds a LitServe docker image around the freshly trained
    adapter, spins it up, waits for readiness, fires two smoke
    inference requests (with and without the adapter), then
    tears the container down. Outcome is recorded in
    `ctx.local_serve_is_ready` (a `LocalServeReadinessEnum`).
    """
    """
    Note that using isolated virtual env
    (using @conda task decorator)
    is advisable to not embark the whole
    pipeline dependencies into the local server.
    We don't for educational purpose,
    keep things "simple" to grasp
    as well as to avoid forcing conda
    (for instance miniconda) as
    a virtual environment management mean
    to the user.
    """
    """
    Note : We load base model from HF-cache
    (mounted as /huggingface_hub_cache
    docker volume) and adapter from local dir
    (mounted as /FuncCallAdapter docker volume).
    """

    # default: not evaluated (e.g. model version not blessed)
    ctx.local_serve_is_ready = LocalServeReadinessEnum.NOT_APPLICABLE

    if ctx.model_version_blessed:
        from retrain_pipelines.utils.docker import \
            env_has_docker

        if env_has_docker():
            model_module_dir = \
                os.path.dirname(
                    importlib.util.find_spec(
                        "retrain_pipelines.model." +
                        os.getenv('retrain_pipeline_type')
                    ).origin)

            # server & data-model & server-config modules artifacts
            files_to_copy = [
                "litserve_server.py",
                "litserve_datamodel.py",
                "litserve_serverconfig.py",
                ".dockerignore" # docker context loading
                                # at image-build time,
                                # exclude model weights
            ]
            for filename in files_to_copy:
                shutil.copy(
                    os.path.join(model_module_dir, "litserve",
                                 filename),
                    os.path.join(ctx.serving_artifacts_local_folder,
                                 filename)
                )

            # save dependencies as artifact
            create_requirements(ctx.serving_artifacts_local_folder,
                                exclude=["numpy", # version conflict
                                                  # quick fix
                                         "cudf-polars-.*", "cuda-python",
                                         "nvidia-.*", "(py)?libcudf-.*",
                                         "nvtx", "rmm-.*", "litserve",
                                         "protobuf", "grpc.*",
                                         "tensorboard",
                                         ".*retrain-pipelines.*"]
            )

            # server config yaml (rendered from packaged template)
            env = Environment(loader=FileSystemLoader(
                os.path.join(model_module_dir, "litserve")))
            template = env.get_template(
                "litserve_serverconfig_template.yaml")
            server_config_data = {
                "port": "8000",
                "max_seq_length": ctx.max_seq_length,
                # NOTE(review): template key is "max_new_token"
                # (singular) while the ctx attr is plural — verify
                # the template expects the singular spelling.
                "max_new_token": ctx.max_new_tokens,
                "base_model": {
                    "repo_id": ctx.hf_base_model_dict["repo_id"],
                    "revision": ctx.hf_base_model_dict["commit_hash"]
                },
                "adapters": [
                    {
                        "name": "func_caller",
                        "path": "/FuncCallAdapter"
                    }
                ]
            }
            server_config_yaml = template.render(server_config_data)
            print(server_config_yaml)
            with open(os.path.join(
                ctx.serving_artifacts_local_folder,
                "litserve_serverconfig.yaml"), 'w'
            ) as output_file:
                output_file.write(server_config_yaml)

            # Dockerfile
            env = Environment(loader=FileSystemLoader(
                os.path.join(model_module_dir)))
            template = env.get_template(
                "Dockerfile.litserve_template")
            # Change CUDA version here from available list
            # @see https://hub.docker.com/r/nvidia/cuda/tags
            dockerfile_content = template.render(
                {"cuda_version": "12.0.0"})
            with open(os.path.join(
                ctx.serving_artifacts_local_folder,
                "Dockerfile.litserve"), 'w'
            ) as output_file:
                output_file.write(dockerfile_content)

            # make sure local health/smoke requests bypass any proxy
            os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0"

            ############################################
            # actually deploy the inference service    #
            ############################################
            start_time = time.time()
            from retrain_pipelines.utils.docker import \
                build_and_run_docker, print_container_log_tail, \
                cleanup_docker
            from retrain_pipelines.model.litserve import \
                endpoint_started, endpoint_is_ready

            ctx.port = 8765
            HF_HUB_CACHE = os.path.realpath(os.path.expanduser(
                os.getenv(
                    "HF_HUB_CACHE",
                    os.path.join(os.getenv("HF_HOME",
                                           "~/.cache/huggingface"),
                                 "hub")
                )))
            print(f"HF_HUB_CACHE : {HF_HUB_CACHE}")
            image_name = container_name = "litserve-model"

            serving_container = build_and_run_docker(
                image_name=image_name, image_tag="1.0",
                build_path=ctx.serving_artifacts_local_folder,
                dockerfile="Dockerfile.litserve",
                ports_publish_dict={'8000/tcp': ctx.port},
                env_vars_dict={
                    "HF_HUB_CACHE": "/huggingface_hub_cache",
                    "HF_TOKEN": os.getenv("HF_TOKEN")
                },
                volumes_dict={
                    ctx.sft_model_dir:
                        {"bind": "/FuncCallAdapter",
                         "mode": "ro"},
                    HF_HUB_CACHE:
                        {"bind": "/huggingface_hub_cache",
                         "mode": "ro"}
                }
            )

            if not serving_container:
                print("failed spinning the LitServe container",
                      file=sys.stderr)
                ctx.local_serve_is_ready = \
                    LocalServeReadinessEnum.FAILURE
                try:
                    cleanup_docker(
                        container_name=container_name,
                        image_name=f"{image_name}:1.0",
                        no_pruning=True # for intermediate layers recycling
                                        # (during later re-runs)
                                        # to avoid long rebuild time
                                        # of exactly the same.
                    )
                except Exception as cleanup_ex:
                    # fail silently (cleanup is best-effort)
                    pass
            else:
                print("Awaiting endpoint launch..")
                start_time = time.time()
                if not endpoint_started(
                    container_name, port=ctx.port, timeout=10*60
                ):
                    print(
                        f"The endpoint '{container_name}' " +
                        f"did not start.")
                    ctx.local_serve_is_ready = \
                        LocalServeReadinessEnum.FAILURE
                # health check on the spun-up endpoint
                elif endpoint_is_ready(port=ctx.port):
                    ctx.local_serve_is_ready = \
                        LocalServeReadinessEnum.SUCCESS
                elapsed_time = time.time() - start_time
                print("deploy_local - Elapsed time: " +
                      f"{elapsed_time:.2f} seconds")
            ############################################
        else:
            # env doesn't have docker
            ctx.local_serve_is_ready = \
                LocalServeReadinessEnum.FAILURE_NO_DOCKER

        # smoke-test the live endpoint, then tear it down
        if LocalServeReadinessEnum.SUCCESS == ctx.local_serve_is_ready:
            from retrain_pipelines.model.litserve.litserve_datamodel \
                import Response

            import requests

            url = f"http://localhost:{ctx.port}/predict"
            headers = {"accept": "application/x-www-form-urlencoded"}

            try:
                # request 1: with the freshly-trained adapter plugged in
                start_time = time.time()
                data = {
                    "adapter_name": "func_caller",
                    "queries_list": '["Hello.", "Is 49 a perfect square?"]'
                }
                print(f"inference test - data: {data}")
                response = requests.post(url, headers=headers, data=data)
                parsed_response = Response(**{"output": response.json()})
                elapsed_time = time.time() - start_time
                print("parsed_response ('func_caller' adapter ON) :" +
                      str(parsed_response) +
                      f"\t-\tElapsed time: {elapsed_time:.2f} seconds")

                # request 2: bare base model (no adapter)
                start_time = time.time()
                data = {
                    "queries_list": '["Hello.", "Is 49 a perfect square?"]'
                }
                print(f"inference test - data: {data}")
                response = requests.post(url, headers=headers, data=data)
                parsed_response = Response(**{"output": response.json()})
                elapsed_time = time.time() - start_time
                print(f"parsed_response (no adapter) : {parsed_response}" +
                      f"\t-\tElapsed time: {elapsed_time:.2f} seconds")

            except Exception as ex:
                print(ex, file=sys.stderr)
                traceback.print_tb(ex.__traceback__, file=sys.stderr)
                ctx.local_serve_is_ready = \
                    LocalServeReadinessEnum.FAILURE
                pass

            try:
                cleanup_docker(
                    container_name=container_name,
                    image_name=f"{image_name}:1.0",
                    no_pruning=True # for intermediate layers recycling
                                    # (during later re-runs)
                                    # to avoid long rebuild time
                                    # of exactly the same.
                )
            except Exception as cleanup_ex:
                # fail silently (cleanup is best-effort)
                pass
    else:
        logger.info("model-version not blessed - skipping")

    return None
1733
+
1734
+
1735
@task
def pipeline_card(_, task_id: int) -> None:
    """
    Render the HTML "custom" pipeline-card for this run.

    Gathers every artifact produced by the upstream tasks
    (EDA figures, training curves, eval metrics, blessing
    status, HF commit info, DAG svg) and renders them through
    the (possibly user-provided) html template. The resulting
    file path is recorded in `ctx.pipeline_card_file_fullname`.
    """
    #############################
    # case of user-provided     #
    # documentation artifact(s) #
    #############################
    # note that user can provide either
    # 'pipeline_card.py' or 'template.html'
    # or 'dataset_readme.py'
    # or 'dataset_readme_template.md'
    # or 'model_readme.py'
    # or 'model_readme_template.md'
    # or any combination of those
    # when specifying custom
    # 'pipeline_card_artifacts_path'
    if "template.html" in os.listdir(
        ctx.pipeline_card_artifacts_path
    ):
        template_dir = ctx.pipeline_card_artifacts_path
    else:
        # fall back to the packaged default template
        template_dir = os.path.dirname(
            importlib.util.find_spec(
                f"retrain_pipelines.pipeline_card."+
                f"{os.getenv('retrain_pipeline_type')}"
            ).origin)
    #############################
    if "pipeline_card.py" in os.listdir(
        ctx.pipeline_card_artifacts_path
    ):
        from retrain_pipelines.utils import get_get_html
        get_html = \
            get_get_html(ctx.pipeline_card_artifacts_path)
    else:
        from retrain_pipelines.pipeline_card import \
            get_html
    from retrain_pipelines.dag_engine.renderer import dag_svg
    #############################

    #############################
    ##   html "custom" card    ##
    #############################
    dt = datetime.now(tz=timezone.utc)
    formatted_dt = dt.strftime("%A %b %d %Y %I:%M:%S %p %Z")
    task_obj_python_cmd = f"sdk.Task({task_id})"
    executions_count = ExecutionsIterator(
        exec_name=ctx.pipeline_name).length()

    params={
        'template_dir': template_dir,
        'title': ctx.pipeline_name,
        "subtitle": f"(Pipeline execution # {executions_count}," + \
                    f" exec_id: {str(ctx.exec_id)} - {formatted_dt})",

        # blessed status / current_blessed version
        'model_version_blessed': ctx.model_version_blessed,
        'current_blessed_version_label': (
            ctx.current_blessed_version_dict["version_label"]
            if ctx.current_blessed_version_dict
            else None
        ),
        'current_blessed_commit_datetime': (
            ctx.current_blessed_version_dict["commit_datetime"]
            if ctx.current_blessed_version_dict
            else None
        ),
        'current_blessed_model_commit_hash': (
            ctx.current_blessed_version_dict["commit_hash"]
            if ctx.current_blessed_version_dict
            else None
        ),
        'current_blessed_run': ctx.current_blessed_exec,

        'LocalServeReadinessEnum': LocalServeReadinessEnum,
        'local_serve_is_ready': ctx.local_serve_is_ready,
        # EDA
        'main_dataset_repo_id': ctx.hf_dataset['repo_id'],
        'main_dataset_commit_hash': ctx.hf_dataset_dict['commit_hash'],
        'main_dataset_commit_datetime': \
            ctx.hf_dataset_dict['commit_datetime'],

        'records_count': ctx.records_count,
        'data_schema': ctx.data_schema,
        'answers_tools_count_fig': ctx.answers_tools_count_fig,
        'words_count_fig': ctx.words_count_fig,

        # model training
        'dataset_repo_id': ctx.dataset_repo_id,
        'dataset_version_label': ctx.dataset_commit_dict["version_label"],
        'dataset_commit_datetime': ctx.dataset_commit_dict["commit_datetime"],
        'dataset_commit_hash': ctx.dataset_commit_dict["commit_hash"],
        'dataset_augmentation_rate': ctx.actual_augmentation_rate,
        'dataset_enrichment_rate': ctx.enrichment_rate,

        # trained model version
        'model_repo_id': ctx.model_repo_id,
        'model_version_label': ctx.model_commit_dict["version_label"],
        'model_commit_datetime': ctx.model_commit_dict["commit_datetime"],
        'model_commit_hash': ctx.model_commit_dict["commit_hash"],

        'cpt_log_history_fig': ctx.cpt_log_history_fig,
        'sft_log_history_fig': ctx.sft_log_history_fig,

        'validation_completions_fig': ctx.validation_completions_fig,

        'hf_base_model_dict': ctx.hf_base_model_dict,
        'pipeline_parameters_dict': {"cpt": ctx.cpt_training_args,
                                     "sft": ctx.sft_training_args},

        'metrics_dict': ctx.perf_metrics,

        'task_obj_python_cmd': task_obj_python_cmd,
        'dag_svg': dag_svg(execution_id=ctx.exec_id)
    }
    html = get_html(params)

    filename = os.path.join(
        os.environ["RP_ARTIFACTS_STORE"],
        ctx.pipeline_name, str(ctx.exec_id),
        "pipeline_card.html"
    )
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    with open(filename, "w", encoding="utf-8") as file:
        file.write(html)
    # FIX: the f-string previously contained no placeholder and
    # logged a literal instead of the pipeline-card path.
    logger.debug(
        "pipeline_card - " +
        f"[bold]pipeline_card_file_fullname : {filename}[/]")

    ctx.pipeline_card_file_fullname = filename
    #############################

    return None
1866
+
1867
+
1868
@task(ui_css=UiCss(background="#FF9900", color="#111827", border="#1F2937"))
def pipeline_to_hub(_) -> None:
    """
    Publish versioned source-code and pipeline-card
    for this run on the Hugging Face Hub.

    Both go to dedicated branches of the model repo
    ("retrain-pipelines_source-code" and
    "retrain-pipelines_pipeline-card"), under a subfolder
    named after the model version label + commit timestamp.
    Resulting commits are recorded on `ctx`.
    """
    # subfolder name is derived from the model-version commit,
    # e.g. "v0.37_20260227_192656740_UTC"
    model_commit_datetime = \
        ctx.model_commit_dict["commit_datetime"]
    timestamp_str = \
        "{:%Y%m%d_%H%M%S}".format(model_commit_datetime) + \
        "{:03d}".format(model_commit_datetime.microsecond//1000) + \
        "_UTC"
    subfolder_name = \
        "v" + ctx.model_commit_dict["version_label"] + \
        "_" + timestamp_str
    commit_datetime = datetime.utcnow()

    ###############################
    #        source-code          #
    ###############################
    # We upload only herein file  #
    # plus user-provided versions #
    # of the customizable ones    #
    # (if any).                   #
    ###############################
    custom_source_files = [os.path.abspath(__file__)]
    # only include customizable artifacts when the user pointed
    # to a non-default 'pipeline_card_artifacts_path'
    if (
        ctx.pipeline_card_artifacts_path != \
        ctx.params_definitions["pipeline_card_artifacts_path"].default
    ):
        candidate_source_files = [
            "pipeline_card.py",
            "template.html",
            "dataset_readme.py",
            "dataset_readme_template.md",
            "model_readme.py",
            "model_readme_template.md"
        ]
        for candidate_source_file in candidate_source_files:
            file_fullpath = os.path.join(
                ctx.pipeline_card_artifacts_path,
                candidate_source_file)
            if os.path.exists(file_fullpath):
                custom_source_files.append(file_fullpath)

    source_code_commit_hash = \
        push_files_to_hub_repo_branch(
            repo_id=ctx.model_repo_id,
            branch_name="retrain-pipelines_source-code",
            file_fullnames=custom_source_files,
            include_requirements_txt=True,
            path_in_repo=subfolder_name,
            commit_message=\
                "source-code for model version " + \
                subfolder_name + \
                f"- retrain-pipelines {__version__}",
            repo_type="model",
            hf_token=os.getenv("HF_TOKEN", None)
        )
    print(source_code_commit_hash)
    ctx.source_code_commit_dict = {
        "repo_id": ctx.model_repo_id,
        "branch_name": "retrain-pipelines_source-code",
        "commit_datetime": commit_datetime,
        "commit_hash": source_code_commit_hash
    }
    ###############################

    ###############################
    #        pipeline-card        #
    ###############################
    pipeline_card_commit_hash = \
        push_files_to_hub_repo_branch(
            repo_id=ctx.model_repo_id,
            branch_name="retrain-pipelines_pipeline-card",
            file_fullnames=[ctx.pipeline_card_file_fullname],
            path_in_repo=subfolder_name,
            commit_message=\
                "pipeline-card for model version " + \
                subfolder_name + \
                f"- retrain-pipelines {__version__}",
            repo_type="model",
            hf_token=os.getenv("HF_TOKEN", None)
        )
    print(pipeline_card_commit_hash)
    ctx.pipeline_card_commit_dict = {
        "repo_id": ctx.model_repo_id,
        "branch_name": "retrain-pipelines_pipeline-card",
        "commit_datetime": commit_datetime,
        "commit_hash": pipeline_card_commit_hash
    }
    ###############################

    return None
1962
+
1963
+
1964
@task
def deploy(_):
    """
    Placeholder for the serving SDK "deploy" call
    (on the target production platform).
    Consider including the portable pipeline-card
    itself to the inference service endpoint!
    """
    # FIX: compare against the enum member (as `infra_validator`
    # does) rather than the magic number 1 — with a plain `Enum`,
    # `member == 1` is always False, which would silently
    # disable this gate.
    if (
        ctx.model_version_blessed and
        ctx.local_serve_is_ready == LocalServeReadinessEnum.SUCCESS
    ):
        pass # your code here

    return None
1977
+
1978
+
1979
@task
def load_test(_):
    """
    Placeholder for a post-deployment load test
    of the serving endpoint.
    """
    # FIX: compare against the enum member (as `infra_validator`
    # does) rather than the magic number 1 — with a plain `Enum`,
    # `member == 1` is always False, which would silently
    # disable this gate.
    if (
        ctx.model_version_blessed and
        ctx.local_serve_is_ready == LocalServeReadinessEnum.SUCCESS
    ):
        pass # your code here

    return None
1989
+
1990
+
1991
@task
def end(_):
    # terminal no-op task marking the end of the DAG
    pass
1994
+
1995
+
1996
+ #--- retraining-pipeline params & DAG ---------------------------------------------------
1997
+
1998
+
1999
+ @dag(ui_css=UiCss(color="#FFDD00", background="#7AD4FF", border="#C28E00"))
2000
+ def retrain_pipeline():
2001
+ """
2002
+ Retraining pipeline with SFT & CPT. Small LLM with pluggable adapter specialized in tool-calling from intrinsic knowledge bank of tools and not from extended context. Model-version blessing. Serving via a custom LitServe toy-server.
2003
+ """
2004
+ # @see https://github.com/unslothai/unsloth/wiki
2005
+
2006
+ #--- flow parameters -------------------------------------------------------
2007
+
2008
+
2009
+ RETRAIN_PIPELINE_TYPE = "mf_unsloth_func_call_litserve"
2010
+ # best way to share the config across subprocesses
2011
+ os.environ["retrain_pipeline_type"] = RETRAIN_PIPELINE_TYPE
2012
+
2013
+ hf_dataset = DagParam(
2014
+ description="dict with 'repo_id' and 'commit_hash' keys. " + \
2015
+ "if 'commit_hash is None, falls back to latest version " +\
2016
+ "of the dataset available in parquet format.\n" +
2017
+ "Note that there are 3 required 'attributes' of type " + \
2018
+ "str, list[str], list[str]",
2019
+ default=dedent("""{
2020
+ "repo_id": "Salesforce/xlam-function-calling-60k",
2021
+ "config_name": "",
2022
+ "commit_hash": "",
2023
+ "attributes": {
2024
+ "query_attr": "query",
2025
+ "answers_attr": "answers",
2026
+ "tools_attr": "tools"
2027
+ }
2028
+ }""")
2029
+ )
2030
+
2031
+ augmentation_rate = DagParam(
2032
+ description="(float) proportion of records to be augmented " + \
2033
+ "(x% of original dataset is created" + \
2034
+ " as additional augmented datapoints), i.e. " + \
2035
+ "truncated queries to serve as negative examples, " + \
2036
+ "meaning they trigger no tool call " + \
2037
+ "due to info incompleteness.",
2038
+ default=.05
2039
+ )
2040
+
2041
+ hf_enrich_dataset = DagParam(
2042
+ description="dict with 'repo_id', 'config_name' and 'commit_hash', " + \
2043
+ "query_attribute' and 'query_attribute_handler' keys. " + \
2044
+ "if 'commit_hash is None, falls back to latest version " + \
2045
+ "of the dataset available in parquet format." + \
2046
+ "'query_attribute' depicts the dataset attribute " + \
2047
+ "from which 'queries' are to be sampled." + \
2048
+ "'query_attribute_handler' serves for attributes " + \
2049
+ "that have complex structure, " + \
2050
+ "other than 'string' datatype.",
2051
+ # @see https://huggingface.co/datasets/google-research-datasets/natural_questions
2052
+ default=dedent("""{
2053
+ "repo_id": "lighteval/natural_questions_clean",
2054
+ "config_name": "",
2055
+ "commit_hash": "",
2056
+ "query_attribute": "question",
2057
+ "query_attribute_handler": "lambda x: x"
2058
+ }""")
2059
+ )
2060
+
2061
+ enrichment_rate = DagParam(
2062
+ description="(float) proportion of records " + \
2063
+ "to be added from the 'hf_enrich_dataset'" + \
2064
+ "(x% of original dataset is sampled and" + \
2065
+ " added as enriching datapoints), i.e. " + \
2066
+ "queries to serve as negative examples, " + \
2067
+ "due to their complete disconnexion " + \
2068
+ "to tool calling situations.",
2069
+ default=.1
2070
+ )
2071
+
2072
+ dataset_repo_id = DagParam(
2073
+ description="(str) The 'repo_id' to be used " + \
2074
+ "for the Hugging Face dataset version push " + \
2075
+ "(will be created at runtime" + \
2076
+ " if doesn't already exist).",
2077
+ default="retrain-pipelines/func_calls"
2078
+ )
2079
+
2080
+ polars_engine = DagParam(
2081
+ description="The engine used by Polars for " + \
2082
+ "dataset querying and processing " + \
2083
+ "(either 'gpu' or 'cpu').",
2084
+ default="gpu"
2085
+ )
2086
+
2087
+ hf_base_model = DagParam(
2088
+ description="(str) dict with 'repo_id' and 'commit_hash' keys." + \
2089
+ "if 'commit_hash is None, falls back " + \
2090
+ "to latest available version of the model.",
2091
+ default=dedent("""{
2092
+ "repo_id": "unsloth/Qwen2.5-1.5B",
2093
+ "commit_hash": ""
2094
+ }""")
2095
+ )
2096
+
2097
+ cpt_training_args = DagParam(
2098
+ description="dict with `TrainingArguments` params " + \
2099
+ "for the CPT job.",
2100
+ default=dedent("""{
2101
+ "warmup_ratio": 0.1,
2102
+ "num_train_epochs": 1
2103
+ }""")
2104
+ )
2105
+
2106
+ sft_training_args = DagParam(
2107
+ description="dict with `TrainingArguments` params " + \
2108
+ "for the SFT job.",
2109
+ default=dedent("""{
2110
+ "warmup_ratio": 0.1,
2111
+ "num_train_epochs": 1
2112
+ }""")
2113
+ )
2114
+
2115
+ model_repo_id = DagParam(
2116
+ description="(str) The 'repo_id' to be used " + \
2117
+ "for the Hugging Face model version push " + \
2118
+ "(will be created at runtime" + \
2119
+ " if doesn't already exist).",
2120
+ default="retrain-pipelines/function_caller"
2121
+ )
2122
+
2123
+ default_pipeline_card_module_dir = \
2124
+ os.path.dirname(
2125
+ importlib.util.find_spec(
2126
+ f"retrain_pipelines.pipeline_card."+
2127
+ f"{RETRAIN_PIPELINE_TYPE}"
2128
+ ).origin)
2129
+ pipeline_card_artifacts_path = DagParam(
2130
+ description="pipeline_card artifacts location " + \
2131
+ "(i.e. dir hosting your optional " + \
2132
+ " custom documentation files :" + \
2133
+ " 'pipeline_card.py' and/or 'template.html'" + \
2134
+ " and/or 'model_readme.py'"+\
2135
+ " and/or 'model_readme_template.md'," + \
2136
+ " and/or 'dataset_readme.py'" + \
2137
+ " and/or 'dataset_readme_template.md' file), " + \
2138
+ "if different from default.",
2139
+ default=default_pipeline_card_module_dir
2140
+ )
2141
+ # TODO - convert from class method to TBD
2142
+ # @staticmethod
2143
+ # def copy_default_dataset_readme_module(
2144
+ # target_dir: str,
2145
+ # exists_ok: bool = False
2146
+ # ) -> None:
2147
+ # os.makedirs(target_dir, exist_ok=True)
2148
+ # if (
2149
+ # not exists_ok and
2150
+ # os.path.exists(os.path.join(target_dir, "dataset_readme.py"))
2151
+ # ):
2152
+ # print("File already exists. Skipping copy.")
2153
+ # else:
2154
+ # filefullname = os.path.join(
2155
+ # default_pipeline_card_module_dir,
2156
+ # "dataset_readme.py"
2157
+ # )
2158
+ # shutil.copy(filefullname, target_dir)
2159
+ # print(filefullname)
2160
+ # TODO - convert from class method to TBD
2161
+ # @staticmethod
2162
+ # def copy_default_dataset_readme_template(
2163
+ # target_dir: str,
2164
+ # exists_ok: bool = False
2165
+ # ) -> None:
2166
+ # os.makedirs(target_dir, exist_ok=True)
2167
+ # if (
2168
+ # not exists_ok and
2169
+ # os.path.exists(os.path.join(target_dir,
2170
+ # "dataset_readme_template.md"))
2171
+ # ):
2172
+ # print("File already exists. Skipping copy.")
2173
+ # else:
2174
+ # filefullname = os.path.join(
2175
+ # default_pipeline_card_module_dir,
2176
+ # "dataset_readme_template.md")
2177
+ # shutil.copy(filefullname, target_dir)
2178
+ # print(filefullname)
2179
+ # TODO - convert from class method to TBD
2180
+ # @staticmethod
2181
+ # def copy_default_model_readme_module(
2182
+ # target_dir: str,
2183
+ # exists_ok: bool = False
2184
+ # ) -> None:
2185
+ # os.makedirs(target_dir, exist_ok=True)
2186
+ # if (
2187
+ # not exists_ok and
2188
+ # os.path.exists(os.path.join(target_dir, "model_readme.py"))
2189
+ # ):
2190
+ # print("File already exists. Skipping copy.")
2191
+ # else:
2192
+ # filefullname = os.path.join(
2193
+ # default_pipeline_card_module_dir,
2194
+ # "model_readme.py"
2195
+ # )
2196
+ # shutil.copy(filefullname, target_dir)
2197
+ # print(filefullname)
2198
+ # TODO - convert from class method to TBD
2199
+ # @staticmethod
2200
+ # def copy_default_model_readme_template(
2201
+ # target_dir: str,
2202
+ # exists_ok: bool = False
2203
+ # ) -> None:
2204
+ # os.makedirs(target_dir, exist_ok=True)
2205
+ # if (
2206
+ # not exists_ok and
2207
+ # os.path.exists(os.path.join(target_dir,
2208
+ # "model_readme_template.md"))
2209
+ # ):
2210
+ # print("File already exists. Skipping copy.")
2211
+ # else:
2212
+ # filefullname = os.path.join(
2213
+ # default_pipeline_card_module_dir,
2214
+ # "model_readme_template.md")
2215
+ # shutil.copy(filefullname, target_dir)
2216
+ # print(filefullname)
2217
+ # TODO - convert from class method to TBD
2218
+ # @staticmethod
2219
+ # def copy_default_pipeline_card_module(
2220
+ # target_dir: str,
2221
+ # exists_ok: bool = False
2222
+ # ) -> None:
2223
+ # os.makedirs(target_dir, exist_ok=True)
2224
+ # if (
2225
+ # not exists_ok and
2226
+ # os.path.exists(os.path.join(target_dir, "pipeline_card.py"))
2227
+ # ):
2228
+ # print("File already exists. Skipping copy.")
2229
+ # else:
2230
+ # filefullname = os.path.join(
2231
+ # default_pipeline_card_module_dir,
2232
+ # "pipeline_card.py"
2233
+ # )
2234
+ # shutil.copy(filefullname, target_dir)
2235
+ # print(filefullname)
2236
+ # TODO - convert from class method to TBD
2237
+ # @staticmethod
2238
+ # def copy_default_pipeline_card_html_template(
2239
+ # target_dir: str,
2240
+ # exists_ok: bool = False
2241
+ # ) -> None:
2242
+ # os.makedirs(target_dir, exist_ok=True)
2243
+ # if (
2244
+ # not exists_ok and
2245
+ # os.path.exists(os.path.join(target_dir, "template.html"))
2246
+ # ):
2247
+ # print("File already exists. Skipping copy.")
2248
+ # else:
2249
+ # filefullname = os.path.join(
2250
+ # default_pipeline_card_module_dir,
2251
+ # "template.html")
2252
+ # shutil.copy(filefullname, target_dir)
2253
+ # print(filefullname)
2254
+
2255
+ del RETRAIN_PIPELINE_TYPE
2256
+
2257
+ #---------------------------------------------------------------------------
2258
+
2259
+ return start >> eda \
2260
+ >> augment_data >> enrich_data >> dataset_to_hub \
2261
+ >> continued_pre_training >> supervised_finetuning \
2262
+ >> evaluate_model >> model_version_blessing \
2263
+ >> model_to_hub >> infra_validator >> pipeline_card \
2264
+ >> pipeline_to_hub >> deploy >> load_test >> end
2265
+