autolabel icon indicating copy to clipboard operation
autolabel copied to clipboard

[Feature Request]: Support for Azure OpenAI as a provider

Open rishabh-bhargava opened this issue 1 year ago • 3 comments

Is your feature request related to a problem? Please describe. We would like to be able to use the OpenAI models through Azure's OpenAI offering.

Describe the solution you'd like Support for Azure OpenAI as a provider.

rishabh-bhargava avatar Jul 03 '23 15:07 rishabh-bhargava

We can use Azure OpenAI LLM implementation available in langchain for this: https://python.langchain.com/docs/modules/model_io/models/llms/integrations/azure_openai_example

The integration into Autolabel can be very similar to OpenAI: https://github.com/refuel-ai/autolabel/blob/main/src/autolabel/models/openai.py (which in turn uses langchain's OpenAI LLM implementation)

A reference PR for recently added support for Cohere: https://github.com/refuel-ai/autolabel/pull/419

nihit avatar Jul 05 '23 06:07 nihit

@nihit @rishabh-bhargava @rajasbansal I want to pick this issue. Let me know if no one is working on this.

shril avatar Nov 01 '23 08:11 shril

We can use Azure OpenAI LLM implementation available in langchain for this: https://python.langchain.com/docs/modules/model_io/models/llms/integrations/azure_openai_example

The integration into Autolabel can be very similar to OpenAI: https://github.com/refuel-ai/autolabel/blob/main/src/autolabel/models/openai.py (which in turn uses langchain's OpenAI LLM implementation)

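A reference PR for recently added support for Cohere: #419

Hi @nihit, changing only autolabel's openai.py does not seem to resolve the issue, because an embedding step (used for few-shot example selection) runs before the LLM labeling step. Even after changing many Python files, I still could not connect to the Azure embedding model; the requests still go to the default OpenAI endpoint (/v1/engines/text-embedding-ada-002/embeddings).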

Like below:

connection broken by 'NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7f7df3fb0400>: Failed to establish a new connection: [Errno 101] Network is unreachable')': /v1/engines/text-embedding-ada-002/embeddings


KeyboardInterrupt Traceback (most recent call last) /aml/autolabel/examples/banking/example_banking.ipynb Cell 15 line 4 2 from autolabel import AutolabelDataset 3 ds = AutolabelDataset("test.csv", config=config) ----> 4 agent.plan(ds)

File /aml/train/lib/python3.8/site-packages/autolabel/labeler.py:389, in LabelingAgent.plan(self, dataset, max_items, start_index) 380 if ( 381 self.config.explanation_column() 382 and len(seed_examples) > 0 383 and self.config.explanation_column() not in list(seed_examples[0].keys()) 384 ): 385 raise ValueError( 386 f"Explanation column {self.config.explanation_column()} not found in dataset.\nMake sure that explanations were generated using labeler.generate_explanations(seed_file)." 387 ) --> 389 self.example_selector = ExampleSelectorFactory.initialize_selector( 390 self.config, 391 [safe_serialize_to_string(example) for example in seed_examples], 392 dataset.df.keys().tolist(), 393 cache=self.generation_cache is not None, 394 ) 396 if self.config.label_selection(): 397 if self.config.task_type() != TaskType.CLASSIFICATION:

File /aml/train/lib/python3.8/site-packages/autolabel/few_shot/__init__.py:118, in ExampleSelectorFactory.initialize_selector(config, examples, columns, cache) 112 if algorithm not in [ 113 FewShotAlgorithm.FIXED, 114 FewShotAlgorithm.LABEL_DIVERSITY_RANDOM, 115 ]: 116 params["cache"] = cache --> 118 return example_cls.from_examples(**params)

File /aml/train/lib/python3.8/site-packages/langchain/prompts/example_selector/semantic_similarity.py:96, in SemanticSimilarityExampleSelector.from_examples(cls, examples, embeddings, vectorstore_cls, k, input_keys, **vectorstore_cls_kwargs) 94 else: 95 string_examples = [" ".join(sorted_values(eg)) for eg in examples] ---> 96 vectorstore = vectorstore_cls.from_texts( 97 string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs 98 ) 99 return cls(vectorstore=vectorstore, k=k, input_keys=input_keys)

File /aml/train/lib/python3.8/site-packages/autolabel/few_shot/vector_store.py:453, in VectorStoreWrapper.from_texts(cls, texts, embedding, metadatas, cache, **kwargs) 436 """Create a vectorstore from raw text. 437 The data will be ephemeral in-memory. 438 Args: ref='/aml/train/lib/python3.8/site-packages/autolabel/few_shot/vector_store.py:0'>0;32m (...) 444 vector_store: Vectorstore with seedset embeddings 445 """ 446 vector_store = cls( 447 embedding_function=embedding, 448 corpus_embeddings=None, ref='/aml/train/lib/python3.8/site-packages/autolabel/few_shot/vector_store.py:0'>0;32m (...) 451 **kwargs, 452 ) --> 453 vector_store.add_texts(texts=texts, metadatas=metadatas) 454 return vector_store

File /aml/train/lib/python3.8/site-packages/autolabel/few_shot/vector_store.py:244, in VectorStoreWrapper.add_texts(self, texts, metadatas) 236 """Run texts through the embeddings and add to the vectorstore. Currently, the vectorstore is reinitialized each time, because we do not require a persistent vector store for example selection. 237 Args: 238 texts (Iterable[str]): Texts to add to the vectorstore. ref='/aml/train/lib/python3.8/site-packages/autolabel/few_shot/vector_store.py:0'>0;32m (...) 241 List[str]: List of IDs of the added texts. 242 """ 243 if self._embedding_function is not None: --> 244 embeddings = self._get_embeddings(texts) 246 self._corpus_embeddings = torch.tensor(embeddings) 247 self._texts = texts

File /aml/train/lib/python3.8/site-packages/autolabel/few_shot/vector_store.py:195, in VectorStoreWrapper._get_embeddings(self, texts) 192 uncached_texts.append(text) 193 uncached_texts_indices.append(idx) --> 195 uncached_embeddings = self._embedding_function.embed_documents( 196 uncached_texts 197 ) 198 self._add_embeddings_to_cache(uncached_texts, uncached_embeddings) 199 for idx, embedding in zip(uncached_texts_indices, uncached_embeddings):

File /aml/train/lib/python3.8/site-packages/langchain/embeddings/openai.py:476, in OpenAIEmbeddings.embed_documents(self, texts, chunk_size) 464 """Call out to OpenAI's embedding endpoint for embedding search docs. 465 466 Args: ref='/aml/train/lib/python3.8/site-packages/langchain/embeddings/openai.py:0'>0;32m (...) 472 List of embeddings, one for each text. 473 """ 474 # NOTE: to keep things simple, we assume the list may contain texts longer 475 # than the maximum context and use length-safe embedding function. --> 476 return self._get_len_safe_embeddings(texts, engine=self.deployment)

File /aml/train/lib/python3.8/site-packages/langchain/embeddings/openai.py:326, in OpenAIEmbeddings._get_len_safe_embeddings(self, texts, engine, chunk_size) 323 _iter = range(0, len(tokens), _chunk_size) 325 for i in _iter: --> 326 response = embed_with_retry( 327 self, 328 input=tokens[i : i + _chunk_size], 329 **self._invocation_params, 330 ) 331 batched_embeddings += [r["embedding"] for r in response["data"]] 333 results: List[List[List[float]]] = [[] for _ in range(len(texts))]

File /aml/train/lib/python3.8/site-packages/langchain/embeddings/openai.py:107, in embed_with_retry(embeddings, **kwargs) 104 response = embeddings.client.create(**kwargs) 105 return _check_response(response) --> 107 return _embed_with_retry(**kwargs)

File /aml/train/lib/python3.8/site-packages/tenacity/__init__.py:289, in BaseRetrying.wraps.<locals>.wrapped_f(*args, **kw) 287 @functools.wraps(f) 288 def wrapped_f(*args: t.Any, **kw: t.Any) -> t.Any: --> 289 return self(f, *args, **kw)

File /aml/train/lib/python3.8/site-packages/tenacity/__init__.py:379, in Retrying.__call__(self, fn, *args, **kwargs) 377 retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs) 378 while True: --> 379 do = self.iter(retry_state=retry_state) 380 if isinstance(do, DoAttempt): 381 try:

File /aml/train/lib/python3.8/site-packages/tenacity/__init__.py:314, in BaseRetrying.iter(self, retry_state) 312 is_explicit_retry = fut.failed and isinstance(fut.exception(), TryAgain) 313 if not (is_explicit_retry or self.retry(retry_state)): --> 314 return fut.result() 316 if self.after is not None: 317 self.after(retry_state)

File /aml/train/lib/python3.8/concurrent/futures/_base.py:437, in Future.result(self, timeout) 435 raise CancelledError() 436 elif self._state == FINISHED: --> 437 return self.__get_result() 439 self._condition.wait(timeout) 441 if self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]:

File /aml/train/lib/python3.8/concurrent/futures/_base.py:389, in Future.__get_result(self) 387 if self._exception: 388 try: --> 389 raise self._exception 390 finally: 391 # Break a reference cycle with the exception in self._exception 392 self = None

File /aml/train/lib/python3.8/site-packages/tenacity/__init__.py:382, in Retrying.__call__(self, fn, *args, **kwargs) 380 if isinstance(do, DoAttempt): 381 try: --> 382 result = fn(*args, **kwargs) 383 except BaseException: # noqa: B902 384 retry_state.set_exception(sys.exc_info()) # type: ignore[arg-type]

File /aml/train/lib/python3.8/site-packages/langchain/embeddings/openai.py:104, in embed_with_retry.<locals>._embed_with_retry(**kwargs) 102 @retry_decorator 103 def _embed_with_retry(**kwargs: Any) -> Any: --> 104 response = embeddings.client.create(**kwargs) 105 return _check_response(response)

File /aml/train/lib/python3.8/site-packages/openai/api_resources/embedding.py:33, in Embedding.create(cls, *args, **kwargs) 31 while True: 32 try: ---> 33 response = super().create(*args, **kwargs) 35 # If a user specifies base64, we'll just return the encoded string. 36 # This is only for the default case. 37 if not user_provided_encoding_format:

File /aml/train/lib/python3.8/site-packages/openai/api_resources/abstract/engine_api_resource.py:155, in EngineAPIResource.create(cls, api_key, api_base, api_type, request_id, api_version, organization, **params) 129 @classmethod 130 def create( 131 cls, ref='/aml/train/lib/python3.8/site-packages/openai/api_resources/abstract/engine_api_resource.py:0'>0;32m (...) 138 **params, 139 ): 140 ( 141 deployment_id, 142 engine, ref='/aml/train/lib/python3.8/site-packages/openai/api_resources/abstract/engine_api_resource.py:0'>0;32m (...) 152 api_key, api_base, api_type, api_version, organization, **params 153 ) --> 155 response, _, api_key = requestor.request( 156 "post", 157 url, 158 params=params, 159 headers=headers, 160 stream=stream, 161 request_id=request_id, 162 request_timeout=request_timeout, 163 ) 165 if stream: 166 # must be an iterator 167 assert not isinstance(response, OpenAIResponse)

File /aml/train/lib/python3.8/site-packages/openai/api_requestor.py:289, in APIRequestor.request(self, method, url, params, headers, files, stream, request_id, request_timeout) 278 def request( 279 self, 280 method, ref='/aml/train/lib/python3.8/site-packages/openai/api_requestor.py:0'>0;32m (...) 287 request_timeout: Optional[Union[float, Tuple[float, float]]] = None, 288 ) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool, str]: --> 289 result = self.request_raw( 290 method.lower(), 291 url, 292 params=params, 293 supplied_headers=headers, 294 files=files, 295 stream=stream, 296 request_id=request_id, 297 request_timeout=request_timeout, 298 ) 299 resp, got_stream = self._interpret_response(result, stream) 300 return resp, got_stream, self.api_key

File /aml/train/lib/python3.8/site-packages/openai/api_requestor.py:606, in APIRequestor.request_raw(self, method, url, params, supplied_headers, files, stream, request_id, request_timeout) 604 _thread_context.session_create_time = time.time() 605 try: --> 606 result = _thread_context.session.request( 607 method, 608 abs_url, 609 headers=headers, 610 data=data, 611 files=files, 612 stream=stream, 613 timeout=request_timeout if request_timeout else TIMEOUT_SECS, 614 proxies=_thread_context.session.proxies, 615 ) 616 except requests.exceptions.Timeout as e: 617 raise error.Timeout("Request timed out: {}".format(e)) from e

File /aml/train/lib/python3.8/site-packages/requests/sessions.py:589, in Session.request(self, method, url, params, data, headers, cookies, files, auth, timeout, allow_redirects, proxies, hooks, stream, verify, cert, json) 584 send_kwargs = { 585 "timeout": timeout, 586 "allow_redirects": allow_redirects, 587 } 588 send_kwargs.update(settings) --> 589 resp = self.send(prep, **send_kwargs) 591 return resp

File /aml/train/lib/python3.8/site-packages/requests/sessions.py:703, in Session.send(self, request, **kwargs) 700 start = preferred_clock() 702 # Send the request --> 703 r = adapter.send(request, **kwargs) 705 # Total elapsed time of the request (approximately) 706 elapsed = preferred_clock() - start

File /aml/train/lib/python3.8/site-packages/requests/adapters.py:486, in HTTPAdapter.send(self, request, stream, timeout, verify, cert, proxies) 483 timeout = TimeoutSauce(connect=timeout, read=timeout) 485 try: --> 486 resp = conn.urlopen( 487 method=request.method, 488 url=url, 489 body=request.body, 490 headers=request.headers, 491 redirect=False, 492 assert_same_host=False, 493 preload_content=False, 494 decode_content=False, 495 retries=self.max_retries, 496 timeout=timeout, 497 chunked=chunked, 498 ) 500 except (ProtocolError, OSError) as err: 501 raise ConnectionError(err, request=request)

File /aml/train/lib/python3.8/site-packages/urllib3/connectionpool.py:826, in HTTPConnectionPool.urlopen(self, method, url, body, headers, retries, redirect, assert_same_host, timeout, pool_timeout, release_conn, chunked, body_pos, **response_kw) 821 if not conn: 822 # Try again 823 log.warning( 824 "Retrying (%r) after connection broken by '%r': %s", retries, err, url 825 ) --> 826 return self.urlopen( 827 method, 828 url, 829 body, 830 headers, 831 retries, 832 redirect, 833 assert_same_host, 834 timeout=timeout, 835 pool_timeout=pool_timeout, 836 release_conn=release_conn, 837 chunked=chunked, 838 body_pos=body_pos, 839 **response_kw 840 ) 842 # Handle redirect? 843 redirect_location = redirect and response.get_redirect_location()

File /aml/train/lib/python3.8/site-packages/urllib3/connectionpool.py:714, in HTTPConnectionPool.urlopen(self, method, url, body, headers, retries, redirect, assert_same_host, timeout, pool_timeout, release_conn, chunked, body_pos, **response_kw) 711 self._prepare_proxy(conn) 713 # Make the request on the httplib connection object. --> 714 httplib_response = self._make_request( 715 conn, 716 method, 717 url, 718 timeout=timeout_obj, 719 body=body, 720 headers=headers, 721 chunked=chunked, 722 ) 724 # If we're going to release the connection in finally:, then 725 # the response doesn't need to know about the connection. Otherwise 726 # it will also try to release it and we'll have a double-release 727 # mess. 728 response_conn = conn if not release_conn else None

File /aml/train/lib/python3.8/site-packages/urllib3/connectionpool.py:403, in HTTPConnectionPool._make_request(self, conn, method, url, timeout, chunked, **httplib_request_kw) 401 # Trigger any extra validation we need to do. 402 try: --> 403 self._validate_conn(conn) 404 except (SocketTimeout, BaseSSLError) as e: 405 # Py2 raises this as a BaseSSLError, Py3 raises it as socket timeout. 406 self._raise_timeout(err=e, url=url, timeout_value=conn.timeout)

File /aml/train/lib/python3.8/site-packages/urllib3/connectionpool.py:1053, in HTTPSConnectionPool._validate_conn(self, conn) 1051 # Force connect early to allow us to validate the connection. 1052 if not getattr(conn, "sock", None): # AppEngine might not have .sock -> 1053 conn.connect() 1055 if not conn.is_verified: 1056 warnings.warn( 1057 ( 1058 "Unverified HTTPS request is being made to host '%s'. " ref='/aml/train/lib/python3.8/site-packages/urllib3/connectionpool.py:0'>0;32m (...) 1063 InsecureRequestWarning, 1064 )

File /aml/train/lib/python3.8/site-packages/urllib3/connection.py:363, in HTTPSConnection.connect(self) 361 def connect(self): 362 # Add certificate verification --> 363 self.sock = conn = self._new_conn() 364 hostname = self.host 365 tls_in_tls = False

File /aml/train/lib/python3.8/site-packages/urllib3/connection.py:174, in HTTPConnection._new_conn(self) 171 extra_kw["socket_options"] = self.socket_options 173 try: --> 174 conn = connection.create_connection( 175 (self._dns_host, self.port), self.timeout, **extra_kw 176 ) 178 except SocketTimeout: 179 raise ConnectTimeoutError( 180 self, 181 "Connection to %s timed out. (connect timeout=%s)" 182 % (self.host, self.timeout), 183 )

File /aml/train/lib/python3.8/site-packages/urllib3/util/connection.py:85, in create_connection(address, timeout, source_address, socket_options) 83 if source_address: 84 sock.bind(source_address) ---> 85 sock.connect(sa) 86 return sock 88 except socket.error as e:

KeyboardInterrupt:

hellangleZ avatar Nov 01 '23 08:11 hellangleZ